Merge tag 'b7207' into dev-hexagon-refactoring-part2
This commit is contained in:
commit
e5e6994700
16
ci/run.sh
16
ci/run.sh
|
|
@ -45,7 +45,7 @@ sd=`dirname $0`
|
|||
cd $sd/../
|
||||
SRC=`pwd`
|
||||
|
||||
CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON"
|
||||
CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON -DGGML_SCHED_NO_REALLOC=ON"
|
||||
|
||||
if [ ! -z ${GG_BUILD_METAL} ]; then
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON"
|
||||
|
|
@ -428,10 +428,10 @@ function gg_run_qwen3_0_6b {
|
|||
|
||||
(time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -ngl 99 -c 1024 -b 512 --chunks 2 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log
|
||||
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa off --no-op-offload) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa on --no-op-offload) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
|
||||
function check_ppl {
|
||||
qnt="$1"
|
||||
|
|
@ -523,8 +523,8 @@ function gg_run_embd_bge_small {
|
|||
|
||||
./bin/llama-quantize ${model_f16} ${model_q8_0} q8_0
|
||||
|
||||
(time ./bin/llama-embedding --model ${model_f16} -p "I believe the meaning of life is" -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
|
||||
(time ./bin/llama-embedding --model ${model_q8_0} -p "I believe the meaning of life is" -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
|
||||
(time ./bin/llama-embedding --model ${model_f16} -p "I believe the meaning of life is" -ngl 99 -c 0 --no-op-offload) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
|
||||
(time ./bin/llama-embedding --model ${model_q8_0} -p "I believe the meaning of life is" -ngl 99 -c 0 --no-op-offload) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
|
||||
|
||||
set +e
|
||||
}
|
||||
|
|
@ -564,7 +564,7 @@ function gg_run_rerank_tiny {
|
|||
model_f16="${path_models}/ggml-model-f16.gguf"
|
||||
|
||||
# for this model, the SEP token is "</s>"
|
||||
(time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log
|
||||
(time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --no-op-offload --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log
|
||||
|
||||
# sample output
|
||||
# rerank score 0: 0.029
|
||||
|
|
|
|||
|
|
@ -980,7 +980,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
[](common_params & params) {
|
||||
params.kv_unified = true;
|
||||
}
|
||||
).set_env("LLAMA_ARG_KV_SPLIT"));
|
||||
).set_env("LLAMA_ARG_KV_UNIFIED"));
|
||||
add_opt(common_arg(
|
||||
{"--no-context-shift"},
|
||||
string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
|
||||
|
|
@ -2639,7 +2639,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
[](common_params &, const std::string & value) {
|
||||
common_log_set_file(common_log_main(), value.c_str());
|
||||
}
|
||||
));
|
||||
).set_env("LLAMA_LOG_FILE"));
|
||||
add_opt(common_arg(
|
||||
{"--log-colors"}, "[on|off|auto]",
|
||||
"Set colored logging ('on', 'off', or 'auto', default: 'auto')\n"
|
||||
|
|
|
|||
|
|
@ -13,6 +13,120 @@
|
|||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder,
|
||||
const common_regex & prefix,
|
||||
size_t rstrip_prefix = 0) {
|
||||
static const std::vector<std::vector<std::string>> args_paths = { { "arguments" } };
|
||||
if (auto res = builder.try_find_regex(prefix)) {
|
||||
builder.move_back(rstrip_prefix);
|
||||
auto tool_calls = builder.consume_json_with_dumped_args(args_paths);
|
||||
if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call array");
|
||||
}
|
||||
} else {
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
}
|
||||
|
||||
static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) {
|
||||
std::string arguments;
|
||||
if (builder.is_partial()) {
|
||||
arguments = (json{
|
||||
{ "code", code + builder.healing_marker() }
|
||||
})
|
||||
.dump();
|
||||
auto idx = arguments.find(builder.healing_marker());
|
||||
if (idx != std::string::npos) {
|
||||
arguments.resize(idx);
|
||||
}
|
||||
} else {
|
||||
arguments = (json{
|
||||
{ "code", code }
|
||||
})
|
||||
.dump();
|
||||
}
|
||||
return arguments;
|
||||
}
|
||||
|
||||
/**
|
||||
* Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
|
||||
* Aggregates the prefix, suffix and in-between text into the content.
|
||||
*/
|
||||
static void parse_json_tool_calls(
|
||||
common_chat_msg_parser & builder,
|
||||
const std::optional<common_regex> & block_open,
|
||||
const std::optional<common_regex> & function_regex_start_only,
|
||||
const std::optional<common_regex> & function_regex,
|
||||
const common_regex & close_regex,
|
||||
const std::optional<common_regex> & block_close,
|
||||
bool allow_raw_python = false,
|
||||
const std::function<std::string(const common_chat_msg_parser::find_regex_result & fres)> & get_function_name =
|
||||
nullptr) {
|
||||
auto parse_tool_calls = [&]() {
|
||||
size_t from = std::string::npos;
|
||||
auto first = true;
|
||||
while (true) {
|
||||
auto start_pos = builder.pos();
|
||||
auto res = function_regex_start_only && first ? builder.try_consume_regex(*function_regex_start_only) :
|
||||
function_regex ? builder.try_find_regex(*function_regex, from) :
|
||||
std::nullopt;
|
||||
|
||||
if (res) {
|
||||
std::string name;
|
||||
if (get_function_name) {
|
||||
name = get_function_name(*res);
|
||||
} else {
|
||||
GGML_ASSERT(res->groups.size() == 2);
|
||||
name = builder.str(res->groups[1]);
|
||||
}
|
||||
first = false;
|
||||
if (name.empty()) {
|
||||
// get_function_name signalled us that we should skip this match and treat it as content.
|
||||
from = res->groups[0].begin + 1;
|
||||
continue;
|
||||
}
|
||||
from = std::string::npos;
|
||||
|
||||
auto maybe_raw_python = name == "python" && allow_raw_python;
|
||||
if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) {
|
||||
if (auto arguments = builder.try_consume_json_with_dumped_args({ {} })) {
|
||||
if (!builder.add_tool_call(name, "", arguments->value) || arguments->is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
builder.consume_regex(close_regex);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (maybe_raw_python) {
|
||||
auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
|
||||
if (!builder.add_tool_call(name, "", arguments)) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
return;
|
||||
}
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
} else {
|
||||
builder.move_to(start_pos);
|
||||
}
|
||||
break;
|
||||
}
|
||||
if (block_close) {
|
||||
builder.consume_regex(*block_close);
|
||||
}
|
||||
builder.consume_spaces();
|
||||
builder.add_content(builder.consume_rest());
|
||||
};
|
||||
if (block_open) {
|
||||
if (auto res = builder.try_find_regex(*block_open)) {
|
||||
parse_tool_calls();
|
||||
} else {
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
} else {
|
||||
parse_tool_calls();
|
||||
}
|
||||
}
|
||||
|
||||
common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax)
|
||||
: input_(input), is_partial_(is_partial), syntax_(syntax)
|
||||
{
|
||||
|
|
@ -532,3 +646,857 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
|
|||
void common_chat_msg_parser::clear_tools() {
|
||||
result_.tool_calls.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* All common_chat_parse_* moved from chat.cpp to chat-parser.cpp below
|
||||
* to reduce incremental compile time for parser changes.
|
||||
*/
|
||||
static void common_chat_parse_generic(common_chat_msg_parser & builder) {
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
static const std::vector<std::vector<std::string>> content_paths = {
|
||||
{"response"},
|
||||
};
|
||||
static const std::vector<std::vector<std::string>> args_paths = {
|
||||
{"tool_call", "arguments"},
|
||||
{"tool_calls", "arguments"},
|
||||
};
|
||||
auto data = builder.consume_json_with_dumped_args(args_paths, content_paths);
|
||||
if (data.value.contains("tool_calls")) {
|
||||
if (!builder.add_tool_calls(data.value.at("tool_calls")) || data.is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool calls");
|
||||
}
|
||||
} else if (data.value.contains("tool_call")) {
|
||||
if (!builder.add_tool_call(data.value.at("tool_call")) || data.is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
} else if (data.value.contains("response")) {
|
||||
const auto & response = data.value.at("response");
|
||||
builder.add_content(response.is_string() ? response.template get<std::string>() : response.dump(2));
|
||||
if (data.is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete response");
|
||||
}
|
||||
} else {
|
||||
throw common_chat_msg_partial_exception("Expected 'tool_call', 'tool_calls' or 'response' in JSON");
|
||||
}
|
||||
}
|
||||
|
||||
static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) {
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
static const common_regex prefix(regex_escape("[TOOL_CALLS]"));
|
||||
parse_prefixed_json_tool_call_array(builder, prefix);
|
||||
}
|
||||
|
||||
static void common_chat_parse_magistral(common_chat_msg_parser & builder) {
|
||||
builder.try_parse_reasoning("[THINK]", "[/THINK]");
|
||||
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
static const common_regex prefix(regex_escape("[TOOL_CALLS]"));
|
||||
parse_prefixed_json_tool_call_array(builder, prefix);
|
||||
}
|
||||
|
||||
static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) {
|
||||
builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>");
|
||||
|
||||
static const common_regex start_action_regex("<\\|START_ACTION\\|>");
|
||||
static const common_regex end_action_regex("<\\|END_ACTION\\|>");
|
||||
static const common_regex start_response_regex("<\\|START_RESPONSE\\|>");
|
||||
static const common_regex end_response_regex("<\\|END_RESPONSE\\|>");
|
||||
|
||||
if (auto res = builder.try_find_regex(start_action_regex)) {
|
||||
// If we didn't extract thoughts, prelude includes them.
|
||||
auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}});
|
||||
for (const auto & tool_call : tool_calls.value) {
|
||||
std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : "";
|
||||
std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : "";
|
||||
std::string arguments = tool_call.contains("parameters") ? tool_call.at("parameters") : "";
|
||||
if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
}
|
||||
if (tool_calls.is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
builder.consume_regex(end_action_regex);
|
||||
} else if (auto res = builder.try_find_regex(start_response_regex)) {
|
||||
if (!builder.try_find_regex(end_response_regex)) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
throw common_chat_msg_partial_exception(end_response_regex.str());
|
||||
}
|
||||
} else {
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
}
|
||||
|
||||
static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
|
||||
builder.try_parse_reasoning("<think>", "</think>");
|
||||
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
static const common_regex function_regex(
|
||||
"\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ");
|
||||
static const common_regex close_regex("\\}\\s*");
|
||||
|
||||
static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\(");
|
||||
static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*");
|
||||
|
||||
if (with_builtin_tools) {
|
||||
static const common_regex builtin_call_regex("<\\|python_tag\\|>");
|
||||
if (auto res = builder.try_find_regex(builtin_call_regex)) {
|
||||
auto fun_res = builder.consume_regex(function_name_regex);
|
||||
auto function_name = builder.str(fun_res.groups[1]);
|
||||
|
||||
common_healing_marker healing_marker;
|
||||
json args = json::object();
|
||||
while (true) {
|
||||
if (auto arg_res = builder.try_consume_regex(arg_name_regex)) {
|
||||
auto arg_name = builder.str(arg_res->groups[1]);
|
||||
auto partial = builder.consume_json();
|
||||
args[arg_name] = partial.json;
|
||||
healing_marker.marker = partial.healing_marker.marker;
|
||||
healing_marker.json_dump_marker = partial.healing_marker.json_dump_marker;
|
||||
builder.consume_spaces();
|
||||
if (!builder.try_consume_literal(",")) {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
builder.consume_literal(")");
|
||||
builder.consume_spaces();
|
||||
|
||||
auto arguments = args.dump();
|
||||
if (!builder.add_tool_call(function_name, "", arguments)) {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
parse_json_tool_calls(
|
||||
builder,
|
||||
/* block_open= */ std::nullopt,
|
||||
/* function_regex_start_only= */ function_regex,
|
||||
/* function_regex= */ std::nullopt,
|
||||
close_regex,
|
||||
std::nullopt);
|
||||
|
||||
}
|
||||
|
||||
static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
|
||||
builder.try_parse_reasoning("<think>", "</think>");
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)");
|
||||
static const common_regex tool_calls_end("<|tool▁calls▁end|>");
|
||||
static const common_regex function_regex("(?:<|tool▁call▁begin|>)?function<|tool▁sep|>([^\n]+)\n```json\n");
|
||||
static const common_regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>");
|
||||
|
||||
parse_json_tool_calls(
|
||||
builder,
|
||||
/* block_open= */ tool_calls_begin,
|
||||
/* function_regex_start_only= */ std::nullopt,
|
||||
function_regex,
|
||||
close_regex,
|
||||
tool_calls_end);
|
||||
}
|
||||
|
||||
static void common_chat_parse_deepseek_v3_1_content(common_chat_msg_parser & builder) {
|
||||
static const common_regex function_regex("(?:<|tool▁call▁begin|>)?([^\\n<]+)(?:<|tool▁sep|>)");
|
||||
|
||||
static const common_regex close_regex("(?:[\\s]*)?<|tool▁call▁end|>");
|
||||
static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)");
|
||||
static const common_regex tool_calls_end("<|tool▁calls▁end|>");
|
||||
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
LOG_DBG("%s: not parse_tool_calls\n", __func__);
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
LOG_DBG("%s: parse_tool_calls\n", __func__);
|
||||
|
||||
parse_json_tool_calls(
|
||||
builder,
|
||||
/* block_open= */ tool_calls_begin,
|
||||
/* function_regex_start_only= */ std::nullopt,
|
||||
function_regex,
|
||||
close_regex,
|
||||
tool_calls_end);
|
||||
}
|
||||
|
||||
static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) {
|
||||
// DeepSeek V3.1 outputs reasoning content between "<think>" and "</think>" tags, followed by regular content
|
||||
// First try to parse using the standard reasoning parsing method
|
||||
LOG_DBG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str());
|
||||
|
||||
auto start_pos = builder.pos();
|
||||
auto found_end_think = builder.try_find_literal("</think>");
|
||||
builder.move_to(start_pos);
|
||||
|
||||
if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) {
|
||||
LOG_DBG("%s: no end_think, not partial, adding content\n", __func__);
|
||||
common_chat_parse_deepseek_v3_1_content(builder);
|
||||
} else if (builder.try_parse_reasoning("<think>", "</think>")) {
|
||||
// If reasoning was parsed successfully, the remaining content is regular content
|
||||
LOG_DBG("%s: parsed reasoning, adding content\n", __func__);
|
||||
// </think><|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>NAME\n```json\nJSON\n```<|tool▁call▁end|><|tool▁calls▁end|>
|
||||
common_chat_parse_deepseek_v3_1_content(builder);
|
||||
} else {
|
||||
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) {
|
||||
LOG_DBG("%s: reasoning_format none, adding content\n", __func__);
|
||||
common_chat_parse_deepseek_v3_1_content(builder);
|
||||
return;
|
||||
}
|
||||
// If no reasoning tags found, check if we should treat everything as reasoning
|
||||
if (builder.syntax().thinking_forced_open) {
|
||||
// If thinking is forced open but no tags found, treat everything as reasoning
|
||||
LOG_DBG("%s: thinking_forced_open, adding reasoning content\n", __func__);
|
||||
builder.add_reasoning_content(builder.consume_rest());
|
||||
} else {
|
||||
LOG_DBG("%s: no thinking_forced_open, adding content\n", __func__);
|
||||
// <|tool▁call▁begin|>NAME<|tool▁sep|>JSON<|tool▁call▁end|>
|
||||
common_chat_parse_deepseek_v3_1_content(builder);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form {
|
||||
/* form.scope_start = */ "<minimax:tool_call>",
|
||||
/* form.tool_start = */ "<invoke name=\"",
|
||||
/* form.tool_sep = */ "\">",
|
||||
/* form.key_start = */ "<parameter name=\"",
|
||||
/* form.key_val_sep = */ "\">",
|
||||
/* form.val_end = */ "</parameter>",
|
||||
/* form.tool_end = */ "</invoke>",
|
||||
/* form.scope_end = */ "</minimax:tool_call>",
|
||||
};
|
||||
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
|
||||
}
|
||||
|
||||
static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
form.scope_start = "<tool_call>";
|
||||
form.tool_start = "<function=";
|
||||
form.tool_sep = ">";
|
||||
form.key_start = "<parameter=";
|
||||
form.key_val_sep = ">";
|
||||
form.val_end = "</parameter>";
|
||||
form.tool_end = "</function>";
|
||||
form.scope_end = "</tool_call>";
|
||||
form.trim_raw_argval = true;
|
||||
return form;
|
||||
})();
|
||||
builder.consume_reasoning_with_xml_tool_calls(form);
|
||||
}
|
||||
|
||||
static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
form.scope_start = "<|tool_calls_section_begin|>";
|
||||
form.tool_start = "<|tool_call_begin|>";
|
||||
form.tool_sep = "<|tool_call_argument_begin|>{";
|
||||
form.key_start = "\"";
|
||||
form.key_val_sep = "\": ";
|
||||
form.val_end = ", ";
|
||||
form.tool_end = "}<|tool_call_end|>";
|
||||
form.scope_end = "<|tool_calls_section_end|>";
|
||||
form.raw_argval = false;
|
||||
form.last_val_end = "";
|
||||
return form;
|
||||
})();
|
||||
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
|
||||
}
|
||||
|
||||
static void common_chat_parse_apriel_1_5(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
form.scope_start = "<tool_calls>[";
|
||||
form.tool_start = "{\"name\": \"";
|
||||
form.tool_sep = "\", \"arguments\": {";
|
||||
form.key_start = "\"";
|
||||
form.key_val_sep = "\": ";
|
||||
form.val_end = ", ";
|
||||
form.tool_end = "}, ";
|
||||
form.scope_end = "]</tool_calls>";
|
||||
form.raw_argval = false;
|
||||
form.last_val_end = "";
|
||||
form.last_tool_end = "}";
|
||||
return form;
|
||||
})();
|
||||
builder.consume_reasoning_with_xml_tool_calls(form, "<thinking>", "</thinking>");
|
||||
}
|
||||
|
||||
static void common_chat_parse_xiaomi_mimo(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
form.scope_start = "";
|
||||
form.tool_start = "<tool_call>\n{\"name\": \"";
|
||||
form.tool_sep = "\", \"arguments\": {";
|
||||
form.key_start = "\"";
|
||||
form.key_val_sep = "\": ";
|
||||
form.val_end = ", ";
|
||||
form.tool_end = "}\n</tool_call>";
|
||||
form.scope_end = "";
|
||||
form.raw_argval = false;
|
||||
form.last_val_end = "";
|
||||
return form;
|
||||
})();
|
||||
builder.consume_reasoning_with_xml_tool_calls(form);
|
||||
}
|
||||
|
||||
static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
|
||||
static const std::string constraint = "(?: (<\\|constrain\\|>)?([a-zA-Z0-9_-]+))";
|
||||
static const std::string recipient("(?: to=functions\\.([^<\\s]+))");
|
||||
|
||||
static const common_regex start_regex("<\\|start\\|>assistant");
|
||||
static const common_regex analysis_regex("<\\|channel\\|>analysis");
|
||||
static const common_regex final_regex("<\\|channel\\|>final" + constraint + "?");
|
||||
static const common_regex preamble_regex("<\\|channel\\|>commentary");
|
||||
static const common_regex tool_call1_regex(recipient + "<\\|channel\\|>(analysis|commentary)" + constraint + "?");
|
||||
static const common_regex tool_call2_regex("<\\|channel\\|>(analysis|commentary)" + recipient + constraint + "?");
|
||||
|
||||
auto consume_end = [&](bool include_end = false) {
|
||||
if (auto res = builder.try_find_literal("<|end|>")) {
|
||||
return res->prelude + (include_end ? builder.str(res->groups[0]) : "");
|
||||
}
|
||||
return builder.consume_rest();
|
||||
};
|
||||
|
||||
auto handle_tool_call = [&](const std::string & name) {
|
||||
if (auto args = builder.try_consume_json_with_dumped_args({{}})) {
|
||||
if (builder.syntax().parse_tool_calls) {
|
||||
if (!builder.add_tool_call(name, "", args->value) || args->is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
} else if (args->is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
auto regex_match = [](const common_regex & regex, const std::string & input) -> std::optional<common_regex_match> {
|
||||
auto match = regex.search(input, 0, true);
|
||||
if (match.type == COMMON_REGEX_MATCH_TYPE_FULL) {
|
||||
return match;
|
||||
}
|
||||
return std::nullopt;
|
||||
};
|
||||
|
||||
do {
|
||||
auto header_start_pos = builder.pos();
|
||||
auto content_start = builder.try_find_literal("<|message|>");
|
||||
if (!content_start) {
|
||||
throw common_chat_msg_partial_exception("incomplete header");
|
||||
}
|
||||
|
||||
auto header = content_start->prelude;
|
||||
|
||||
if (auto match = regex_match(tool_call1_regex, header)) {
|
||||
auto group = match->groups[1];
|
||||
auto name = header.substr(group.begin, group.end - group.begin);
|
||||
handle_tool_call(name);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto match = regex_match(tool_call2_regex, header)) {
|
||||
auto group = match->groups[2];
|
||||
auto name = header.substr(group.begin, group.end - group.begin);
|
||||
handle_tool_call(name);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (regex_match(analysis_regex, header)) {
|
||||
builder.move_to(header_start_pos);
|
||||
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) {
|
||||
builder.add_content(consume_end(true));
|
||||
} else {
|
||||
builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|end|>");
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if(regex_match(final_regex, header) || regex_match(preamble_regex, header)) {
|
||||
builder.add_content(consume_end());
|
||||
continue;
|
||||
}
|
||||
|
||||
// Possibly a malformed message, attempt to recover by rolling
|
||||
// back to pick up the next <|start|>
|
||||
LOG_DBG("%s: unknown header from message: %s\n", __func__, header.c_str());
|
||||
builder.move_to(header_start_pos);
|
||||
} while (builder.try_find_regex(start_regex, std::string::npos, false));
|
||||
|
||||
auto remaining = builder.consume_rest();
|
||||
if (!remaining.empty()) {
|
||||
LOG_DBG("%s: content after last message: %s\n", __func__, remaining.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form {
|
||||
/* form.scope_start = */ "",
|
||||
/* form.tool_start = */ "<tool_call>",
|
||||
/* form.tool_sep = */ "",
|
||||
/* form.key_start = */ "<arg_key>",
|
||||
/* form.key_val_sep = */ "</arg_key>",
|
||||
/* form.val_end = */ "</arg_value>",
|
||||
/* form.tool_end = */ "</tool_call>",
|
||||
/* form.scope_end = */ "",
|
||||
/* form.key_val_sep2 = */ "<arg_value>",
|
||||
};
|
||||
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
|
||||
}
|
||||
|
||||
static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) {
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
static const common_regex prefix(regex_escape(" functools["));
|
||||
parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1);
|
||||
}
|
||||
|
||||
static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder) {
|
||||
static const common_regex function_regex_start_only(R"((\w+\n\{|python\n|all\n))");
|
||||
static const common_regex function_regex(R"(>>>(\w+\n\{|python\n|all\n))");
|
||||
static const common_regex close_regex(R"(\s*)");
|
||||
|
||||
parse_json_tool_calls(
|
||||
builder,
|
||||
std::nullopt,
|
||||
function_regex_start_only,
|
||||
function_regex,
|
||||
close_regex,
|
||||
std::nullopt,
|
||||
/* allow_raw_python= */ true,
|
||||
/* get_function_name= */ [&](const auto & res) -> std::string {
|
||||
auto at_start = res.groups[0].begin == 0;
|
||||
auto name = builder.str(res.groups[1]);
|
||||
if (!name.empty() && name.back() == '{') {
|
||||
// Unconsume the opening brace '{' to ensure the JSON parsing goes well.
|
||||
builder.move_back(1);
|
||||
}
|
||||
auto idx = name.find_last_not_of("\n{");
|
||||
name = name.substr(0, idx + 1);
|
||||
if (at_start && name == "all") {
|
||||
return "";
|
||||
}
|
||||
return name;
|
||||
});
|
||||
}
|
||||
|
||||
static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) {
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
|
||||
static const common_regex python_tag_regex(regex_escape("<|python_tag|>"));
|
||||
|
||||
static const common_regex function_regex(R"(<function=(\w+)>)");
|
||||
static const common_regex close_regex(R"(</function>)");
|
||||
|
||||
parse_json_tool_calls(
|
||||
builder,
|
||||
/* block_open= */ std::nullopt,
|
||||
/* function_regex_start_only= */ std::nullopt,
|
||||
function_regex,
|
||||
close_regex,
|
||||
std::nullopt);
|
||||
|
||||
if (auto res = builder.try_find_regex(python_tag_regex)) {
|
||||
auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
|
||||
builder.add_tool_call("python", "", arguments);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
|
||||
builder.try_parse_reasoning("<think>", "</think>");
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
static const common_regex open_regex(
|
||||
"(?:"
|
||||
"(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
|
||||
"(" // match 2 (open_tag)
|
||||
"<tool_call>"
|
||||
"|<function_call>"
|
||||
"|<tool>"
|
||||
"|<tools>"
|
||||
"|<response>"
|
||||
"|<json>"
|
||||
"|<xml>"
|
||||
"|<JSON>"
|
||||
")?"
|
||||
"(\\s*\\{\\s*\"name\")" // match 3 (named tool call)
|
||||
")"
|
||||
"|<function=([^>]+)>" // match 4 (function name)
|
||||
"|<function name=\"([^\"]+)\">" // match 5 (function name again)
|
||||
);
|
||||
|
||||
while (auto res = builder.try_find_regex(open_regex)) {
|
||||
const auto & block_start = res->groups[1];
|
||||
std::string block_end = block_start.empty() ? "" : "```";
|
||||
|
||||
const auto & open_tag = res->groups[2];
|
||||
std::string close_tag;
|
||||
|
||||
if (!res->groups[3].empty()) {
|
||||
builder.move_to(res->groups[3].begin);
|
||||
close_tag = open_tag.empty() ? "" : "</" + builder.str(open_tag).substr(1);
|
||||
|
||||
if (auto tool_call = builder.try_consume_json_with_dumped_args({{"arguments"}})) {
|
||||
if (!builder.add_tool_call(tool_call->value) || tool_call->is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
builder.consume_spaces();
|
||||
builder.consume_literal(close_tag);
|
||||
builder.consume_spaces();
|
||||
if (!block_end.empty()) {
|
||||
builder.consume_literal(block_end);
|
||||
builder.consume_spaces();
|
||||
}
|
||||
} else {
|
||||
throw common_chat_msg_partial_exception("failed to parse tool call");
|
||||
}
|
||||
} else {
|
||||
auto function_name = builder.str(res->groups[4]);
|
||||
if (function_name.empty()) {
|
||||
function_name = builder.str(res->groups[5]);
|
||||
}
|
||||
GGML_ASSERT(!function_name.empty());
|
||||
|
||||
close_tag = "</function>";
|
||||
|
||||
if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) {
|
||||
if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
builder.consume_spaces();
|
||||
builder.consume_literal(close_tag);
|
||||
builder.consume_spaces();
|
||||
if (!block_end.empty()) {
|
||||
builder.consume_literal(block_end);
|
||||
builder.consume_spaces();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
|
||||
static void common_chat_parse_granite(common_chat_msg_parser & builder) {
|
||||
// Parse thinking tags
|
||||
static const common_regex start_think_regex(regex_escape("<think>"));
|
||||
static const common_regex end_think_regex(regex_escape("</think>"));
|
||||
// Granite models output partial tokens such as "<" and "<think".
|
||||
// By leveraging try_consume_regex()/try_find_regex() throwing
|
||||
// common_chat_msg_partial_exception for these partial tokens,
|
||||
// processing is interrupted and the tokens are not passed to add_content().
|
||||
if (auto res = builder.try_consume_regex(start_think_regex)) {
|
||||
// Restore position for try_parse_reasoning()
|
||||
builder.move_to(res->groups[0].begin);
|
||||
builder.try_find_regex(end_think_regex, std::string::npos, false);
|
||||
// Restore position for try_parse_reasoning()
|
||||
builder.move_to(res->groups[0].begin);
|
||||
}
|
||||
builder.try_parse_reasoning("<think>", "</think>");
|
||||
|
||||
// Parse response tags
|
||||
static const common_regex start_response_regex(regex_escape("<response>"));
|
||||
static const common_regex end_response_regex(regex_escape("</response>"));
|
||||
// Granite models output partial tokens such as "<" and "<response".
|
||||
// Same hack as reasoning parsing.
|
||||
if (builder.try_consume_regex(start_response_regex)) {
|
||||
builder.try_find_regex(end_response_regex);
|
||||
}
|
||||
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
// Look for tool calls
|
||||
static const common_regex tool_call_regex(regex_escape("<|tool_call|>"));
|
||||
if (auto res = builder.try_find_regex(tool_call_regex)) {
|
||||
builder.move_to(res->groups[0].end);
|
||||
|
||||
// Expect JSON array of tool calls
|
||||
if (auto tool_call = builder.try_consume_json_with_dumped_args({{{"arguments"}}})) {
|
||||
if (!builder.add_tool_calls(tool_call->value) || tool_call->is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
}
|
||||
|
||||
static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) {
|
||||
// Parse thinking tags
|
||||
builder.try_parse_reasoning("<think>", "</think>");
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
// Look for tool calls
|
||||
static const common_regex tool_call_regex(regex_escape("<TOOLCALL>"));
|
||||
if (auto res = builder.try_find_regex(tool_call_regex)) {
|
||||
builder.move_to(res->groups[0].end);
|
||||
|
||||
// Expect JSON array of tool calls
|
||||
auto tool_calls_data = builder.consume_json();
|
||||
if (tool_calls_data.json.is_array()) {
|
||||
if (!builder.try_consume_literal("</TOOLCALL>")) {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||
}
|
||||
builder.add_tool_calls(tool_calls_data.json);
|
||||
} else {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||
}
|
||||
}
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
|
||||
static void common_chat_parse_apertus(common_chat_msg_parser & builder) {
|
||||
// Parse thinking tags
|
||||
builder.try_parse_reasoning("<|inner_prefix|>", "<|inner_suffix|>");
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
// Look for tool calls
|
||||
static const common_regex tool_call_regex(regex_escape("<|tools_prefix|>"));
|
||||
if (auto res = builder.try_find_regex(tool_call_regex)) {
|
||||
builder.move_to(res->groups[0].end);
|
||||
|
||||
auto tool_calls_data = builder.consume_json();
|
||||
if (tool_calls_data.json.is_array()) {
|
||||
builder.consume_spaces();
|
||||
if (!builder.try_consume_literal("<|tools_suffix|>")) {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||
}
|
||||
for (const auto & value : tool_calls_data.json) {
|
||||
if (value.is_object()) {
|
||||
builder.add_tool_call_short_form(value);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||
}
|
||||
}
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
|
||||
|
||||
static void common_chat_parse_lfm2(common_chat_msg_parser & builder) {
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
// LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|>
|
||||
static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>"));
|
||||
static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>"));
|
||||
|
||||
// Loop through all tool calls
|
||||
while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) {
|
||||
builder.move_to(res->groups[0].end);
|
||||
|
||||
// Parse JSON array format: [{"name": "...", "arguments": {...}}]
|
||||
auto tool_calls_data = builder.consume_json();
|
||||
|
||||
// Consume end marker
|
||||
builder.consume_spaces();
|
||||
if (!builder.try_consume_regex(tool_call_end_regex)) {
|
||||
throw common_chat_msg_partial_exception("Expected <|tool_call_end|>");
|
||||
}
|
||||
|
||||
// Process each tool call in the array
|
||||
if (tool_calls_data.json.is_array()) {
|
||||
for (const auto & tool_call : tool_calls_data.json) {
|
||||
if (!tool_call.is_object()) {
|
||||
throw common_chat_msg_partial_exception("Tool call must be an object");
|
||||
}
|
||||
|
||||
if (!tool_call.contains("name")) {
|
||||
throw common_chat_msg_partial_exception("Tool call missing 'name' field");
|
||||
}
|
||||
|
||||
std::string function_name = tool_call.at("name");
|
||||
std::string arguments = "{}";
|
||||
|
||||
if (tool_call.contains("arguments")) {
|
||||
if (tool_call.at("arguments").is_object()) {
|
||||
arguments = tool_call.at("arguments").dump();
|
||||
} else if (tool_call.at("arguments").is_string()) {
|
||||
arguments = tool_call.at("arguments");
|
||||
}
|
||||
}
|
||||
|
||||
if (!builder.add_tool_call(function_name, "", arguments)) {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
throw common_chat_msg_partial_exception("Expected JSON array for tool calls");
|
||||
}
|
||||
|
||||
// Consume any trailing whitespace after this tool call
|
||||
builder.consume_spaces();
|
||||
}
|
||||
|
||||
// Consume any remaining content after all tool calls
|
||||
auto remaining = builder.consume_rest();
|
||||
if (!string_strip(remaining).empty()) {
|
||||
builder.add_content(remaining);
|
||||
}
|
||||
}
|
||||
|
||||
static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form {
|
||||
/* form.scope_start = */ "<seed:tool_call>",
|
||||
/* form.tool_start = */ "<function=",
|
||||
/* form.tool_sep = */ ">",
|
||||
/* form.key_start = */ "<parameter=",
|
||||
/* form.key_val_sep = */ ">",
|
||||
/* form.val_end = */ "</parameter>",
|
||||
/* form.tool_end = */ "</function>",
|
||||
/* form.scope_end = */ "</seed:tool_call>",
|
||||
};
|
||||
builder.consume_reasoning_with_xml_tool_calls(form, "<seed:think>", "</seed:think>");
|
||||
}
|
||||
|
||||
static void common_chat_parse_content_only(common_chat_msg_parser & builder) {
|
||||
builder.try_parse_reasoning("<think>", "</think>");
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
|
||||
static void common_chat_parse(common_chat_msg_parser & builder) {
|
||||
LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str());
|
||||
|
||||
switch (builder.syntax().format) {
|
||||
case COMMON_CHAT_FORMAT_CONTENT_ONLY:
|
||||
common_chat_parse_content_only(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_GENERIC:
|
||||
common_chat_parse_generic(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_MISTRAL_NEMO:
|
||||
common_chat_parse_mistral_nemo(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_MAGISTRAL:
|
||||
common_chat_parse_magistral(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_LLAMA_3_X:
|
||||
common_chat_parse_llama_3_1(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS:
|
||||
common_chat_parse_llama_3_1(builder, /* with_builtin_tools= */ true);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
|
||||
common_chat_parse_deepseek_r1(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1:
|
||||
common_chat_parse_deepseek_v3_1(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2:
|
||||
common_chat_parse_functionary_v3_2(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1:
|
||||
common_chat_parse_functionary_v3_1_llama_3_1(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_HERMES_2_PRO:
|
||||
common_chat_parse_hermes_2_pro(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
|
||||
common_chat_parse_firefunction_v2(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_COMMAND_R7B:
|
||||
common_chat_parse_command_r7b(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_GRANITE:
|
||||
common_chat_parse_granite(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_GPT_OSS:
|
||||
common_chat_parse_gpt_oss(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_SEED_OSS:
|
||||
common_chat_parse_seed_oss(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_NEMOTRON_V2:
|
||||
common_chat_parse_nemotron_v2(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_APERTUS:
|
||||
common_chat_parse_apertus(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS:
|
||||
common_chat_parse_lfm2(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_MINIMAX_M2:
|
||||
common_chat_parse_minimax_m2(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_GLM_4_5:
|
||||
common_chat_parse_glm_4_5(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_KIMI_K2:
|
||||
common_chat_parse_kimi_k2(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML:
|
||||
common_chat_parse_qwen3_coder_xml(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_APRIEL_1_5:
|
||||
common_chat_parse_apriel_1_5(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_XIAOMI_MIMO:
|
||||
common_chat_parse_xiaomi_mimo(builder);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
|
||||
}
|
||||
builder.finish();
|
||||
}
|
||||
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
|
||||
common_chat_msg_parser builder(input, is_partial, syntax);
|
||||
try {
|
||||
common_chat_parse(builder);
|
||||
} catch (const common_chat_msg_partial_exception & ex) {
|
||||
LOG_DBG("Partial parse: %s\n", ex.what());
|
||||
if (!is_partial) {
|
||||
builder.clear_tools();
|
||||
builder.move_to(0);
|
||||
common_chat_parse_content_only(builder);
|
||||
}
|
||||
}
|
||||
auto msg = builder.result();
|
||||
if (!is_partial) {
|
||||
LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat<json>({msg}).at(0).dump().c_str());
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
|
|
|||
952
common/chat.cpp
952
common/chat.cpp
File diff suppressed because it is too large
Load Diff
|
|
@ -517,16 +517,18 @@ static bool common_pull_file(httplib::Client & cli,
|
|||
headers.emplace("Range", "bytes=" + std::to_string(existing_size) + "-");
|
||||
}
|
||||
|
||||
std::atomic<size_t> downloaded{existing_size};
|
||||
const char * func = __func__; // avoid __func__ inside a lambda
|
||||
size_t downloaded = existing_size;
|
||||
size_t progress_step = 0;
|
||||
|
||||
auto res = cli.Get(resolve_path, headers,
|
||||
[&](const httplib::Response &response) {
|
||||
if (existing_size > 0 && response.status != 206) {
|
||||
LOG_WRN("%s: server did not respond with 206 Partial Content for a resume request. Status: %d\n", __func__, response.status);
|
||||
LOG_WRN("%s: server did not respond with 206 Partial Content for a resume request. Status: %d\n", func, response.status);
|
||||
return false;
|
||||
}
|
||||
if (existing_size == 0 && response.status != 200) {
|
||||
LOG_WRN("%s: download received non-successful status code: %d\n", __func__, response.status);
|
||||
LOG_WRN("%s: download received non-successful status code: %d\n", func, response.status);
|
||||
return false;
|
||||
}
|
||||
if (total_size == 0 && response.has_header("Content-Length")) {
|
||||
|
|
@ -534,7 +536,7 @@ static bool common_pull_file(httplib::Client & cli,
|
|||
size_t content_length = std::stoull(response.get_header_value("Content-Length"));
|
||||
total_size = existing_size + content_length;
|
||||
} catch (const std::exception &e) {
|
||||
LOG_WRN("%s: invalid Content-Length header: %s\n", __func__, e.what());
|
||||
LOG_WRN("%s: invalid Content-Length header: %s\n", func, e.what());
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
|
@ -542,11 +544,16 @@ static bool common_pull_file(httplib::Client & cli,
|
|||
[&](const char *data, size_t len) {
|
||||
ofs.write(data, len);
|
||||
if (!ofs) {
|
||||
LOG_ERR("%s: error writing to file: %s\n", __func__, path_tmp.c_str());
|
||||
LOG_ERR("%s: error writing to file: %s\n", func, path_tmp.c_str());
|
||||
return false;
|
||||
}
|
||||
downloaded += len;
|
||||
print_progress(downloaded, total_size);
|
||||
progress_step += len;
|
||||
|
||||
if (progress_step >= total_size / 1000 || downloaded == total_size) {
|
||||
print_progress(downloaded, total_size);
|
||||
progress_step = 0;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
nullptr
|
||||
|
|
|
|||
|
|
@ -268,10 +268,10 @@ static bool is_reserved_name(const std::string & name) {
|
|||
}
|
||||
|
||||
std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+");
|
||||
std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"]");
|
||||
std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"\\\\]");
|
||||
std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]");
|
||||
std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
|
||||
{'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}
|
||||
{'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}, {'\\', "\\\\"}
|
||||
};
|
||||
|
||||
std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
|
||||
|
|
|
|||
|
|
@ -4183,6 +4183,36 @@ class Qwen3MoeModel(Qwen2MoeModel):
|
|||
super().set_vocab()
|
||||
|
||||
|
||||
@ModelBase.register("Qwen3NextForCausalLM")
|
||||
class Qwen3NextModel(Qwen2MoeModel):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN3NEXT
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_ssm_conv_kernel(self.hparams["linear_conv_kernel_dim"])
|
||||
self.gguf_writer.add_ssm_state_size(self.hparams["linear_key_head_dim"])
|
||||
self.gguf_writer.add_ssm_group_count(self.hparams["linear_num_key_heads"])
|
||||
self.gguf_writer.add_ssm_time_step_rank(self.hparams["linear_num_value_heads"])
|
||||
self.gguf_writer.add_ssm_inner_size(self.hparams["linear_value_head_dim"] * self.hparams["linear_num_value_heads"])
|
||||
if (rope_dim := self.hparams.get("head_dim")) is None:
|
||||
rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
|
||||
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.25)))
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if name.startswith("mtp"):
|
||||
return [] # ignore MTP layers for now
|
||||
if name.endswith(".A_log"):
|
||||
data_torch = -torch.exp(data_torch)
|
||||
elif name.endswith(".dt_bias"):
|
||||
name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
|
||||
elif "conv1d" in name:
|
||||
data_torch = data_torch.squeeze()
|
||||
elif name.endswith("norm.weight") and not name.endswith("linear_attn.norm.weight"):
|
||||
data_torch = data_torch + 1
|
||||
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("RND1")
|
||||
class RND1Model(Qwen2MoeModel):
|
||||
model_arch = gguf.MODEL_ARCH.RND1
|
||||
|
|
|
|||
|
|
@ -42,6 +42,9 @@ The following releases are verified and recommended:
|
|||
|
||||
## News
|
||||
|
||||
- 2025.11
|
||||
- Support malloc memory on device more than 4GB.
|
||||
|
||||
- 2025.2
|
||||
- Optimize MUL_MAT Q4_0 on Intel GPU for all dGPUs and built-in GPUs since MTL. Increase the performance of LLM (llama-2-7b.Q4_0.gguf) 21%-87% on Intel GPUs (MTL, ARL-H, Arc, Flex, PVC).
|
||||
|GPU|Base tokens/s|Increased tokens/s|Percent|
|
||||
|
|
@ -789,6 +792,8 @@ use 1 SYCL GPUs: [0] with Max compute units:512
|
|||
| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. |
|
||||
| GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. |
|
||||
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
|
||||
| UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS | 0 (default) or 1 | Support malloc device memory more than 4GB.|
|
||||
|
||||
|
||||
|
||||
## Known Issues
|
||||
|
|
@ -835,6 +840,14 @@ use 1 SYCL GPUs: [0] with Max compute units:512
|
|||
| The default context is too big. It leads to excessive memory usage.|Set `-c 8192` or a smaller value.|
|
||||
| The model is too big and requires more memory than what is available.|Choose a smaller model or change to a smaller quantization, like Q5 -> Q4;<br>Alternatively, use more than one device to load model.|
|
||||
|
||||
- `ggml_backend_sycl_buffer_type_alloc_buffer: can't allocate 5000000000 Bytes of memory on device`
|
||||
|
||||
You need to enable to support 4GB memory malloc by:
|
||||
```
|
||||
export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
|
||||
set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
|
||||
```
|
||||
|
||||
### **GitHub contribution**:
|
||||
Please add the `SYCL :` prefix/tag in issues/PRs titles to help the SYCL contributors to check/address them without delay.
|
||||
|
||||
|
|
|
|||
|
|
@ -104,12 +104,16 @@ int main(int argc, char ** argv) {
|
|||
|
||||
params.embedding = true;
|
||||
|
||||
// get max number of sequences per batch
|
||||
const int n_seq_max = llama_max_parallel_sequences();
|
||||
|
||||
// if the number of prompts that would be encoded is known in advance, it's more efficient to specify the
|
||||
// --parallel argument accordingly. for convenience, if not specified, we fallback to unified KV cache
|
||||
// in order to support any number of prompts
|
||||
if (params.n_parallel == 1) {
|
||||
LOG_INF("%s: n_parallel == 1 -> unified KV cache is enabled\n", __func__);
|
||||
params.kv_unified = true;
|
||||
params.n_parallel = n_seq_max;
|
||||
}
|
||||
|
||||
// utilize the full context
|
||||
|
|
@ -123,9 +127,6 @@ int main(int argc, char ** argv) {
|
|||
params.n_ubatch = params.n_batch;
|
||||
}
|
||||
|
||||
// get max number of sequences per batch
|
||||
const int n_seq_max = llama_max_parallel_sequences();
|
||||
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
|
|
|
|||
|
|
@ -231,9 +231,9 @@ DOT = '[^\\x0A\\x0D]'
|
|||
RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()])
|
||||
|
||||
INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+')
|
||||
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
|
||||
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\\]')
|
||||
GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]')
|
||||
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]'}
|
||||
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]', '\\': '\\\\'}
|
||||
|
||||
NON_LITERAL_SET = set('|.()[]{}*+?')
|
||||
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('^$.[]()|{}*+?')
|
||||
|
|
|
|||
|
|
@ -4,6 +4,11 @@ set -e
|
|||
|
||||
# First try command line argument, then environment variable, then file
|
||||
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
|
||||
MODEL_TESTING_PROMPT="${2:-"$MODEL_TESTING_PROMPT"}"
|
||||
|
||||
if [ -z "$MODEL_TESTING_PROMPT"]; then
|
||||
MODEL_TESTING_PROMPT="Hello, my name is"
|
||||
fi
|
||||
|
||||
# Final check if we have a model path
|
||||
if [ -z "$CONVERTED_MODEL" ]; then
|
||||
|
|
@ -14,7 +19,8 @@ if [ -z "$CONVERTED_MODEL" ]; then
|
|||
fi
|
||||
|
||||
echo $CONVERTED_MODEL
|
||||
echo $MODEL_TESTING_PROMPT
|
||||
|
||||
cmake --build ../../build --target llama-logits -j8
|
||||
|
||||
../../build/bin/llama-logits -m "$CONVERTED_MODEL" "Hello, my name is"
|
||||
../../build/bin/llama-logits -m "$CONVERTED_MODEL" "$MODEL_TESTING_PROMPT"
|
||||
|
|
|
|||
|
|
@ -184,8 +184,12 @@ model_name = os.path.basename(model_path)
|
|||
# of using AutoModelForCausalLM.
|
||||
print(f"Model class: {model.__class__.__name__}")
|
||||
|
||||
prompt = "Hello, my name is"
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
||||
device = next(model.parameters()).device
|
||||
if os.getenv("MODEL_TESTING_PROMPT"):
|
||||
prompt = os.getenv("MODEL_TESTING_PROMPT")
|
||||
else:
|
||||
prompt = "Hello, my name is"
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
|
||||
|
||||
print(f"Input tokens: {input_ids}")
|
||||
print(f"Input text: {repr(prompt)}")
|
||||
|
|
|
|||
|
|
@ -15,6 +15,9 @@ MODEL_FILE=models/llama-2-7b.Q4_0.gguf
|
|||
NGL=99
|
||||
CONTEXT=4096
|
||||
|
||||
#support malloc device memory more than 4GB.
|
||||
export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
|
||||
|
||||
if [ $# -gt 0 ]; then
|
||||
GGML_SYCL_DEVICE=$1
|
||||
echo "use $GGML_SYCL_DEVICE as main GPU"
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
# If you want more control, DPC++ Allows selecting a specific device through the
|
||||
# following environment variable
|
||||
#export ONEAPI_DEVICE_SELECTOR="level_zero:0"
|
||||
export ONEAPI_DEVICE_SELECTOR="level_zero:0"
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
|
||||
#export GGML_SYCL_DEBUG=1
|
||||
|
|
@ -18,11 +18,14 @@ MODEL_FILE=models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf
|
|||
NGL=99 # Layers offloaded to the GPU. If the device runs out of memory, reduce this value according to the model you are using.
|
||||
CONTEXT=4096
|
||||
|
||||
#support malloc device memory more than 4GB.
|
||||
export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
|
||||
|
||||
if [ $# -gt 0 ]; then
|
||||
GGML_SYCL_DEVICE=$1
|
||||
echo "Using $GGML_SYCL_DEVICE as the main GPU"
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none
|
||||
else
|
||||
#use multiple GPUs with same max compute units
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -c ${CONTEXT}
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT}
|
||||
fi
|
||||
|
|
|
|||
|
|
@ -5,5 +5,7 @@
|
|||
set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
|
||||
@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force
|
||||
|
||||
:: support malloc device memory more than 4GB.
|
||||
set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
|
||||
|
||||
.\build\bin\llama-cli.exe -m models\llama-2-7b.Q4_0.gguf -p %INPUT2% -n 400 -e -ngl 99 -s 0
|
||||
|
|
|
|||
|
|
@ -5,5 +5,7 @@
|
|||
set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
|
||||
@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force
|
||||
|
||||
:: support malloc device memory more than 4GB.
|
||||
set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
|
||||
|
||||
.\build\bin\llama-cli.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -p %INPUT2% -n 400 -e -ngl 99
|
||||
.\build\bin\llama-cli.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -p %INPUT2% -n 400 -s 0 -e -ngl 99
|
||||
|
|
|
|||
|
|
@ -183,6 +183,7 @@ endif()
|
|||
# ggml core
|
||||
set(GGML_SCHED_MAX_COPIES "4" CACHE STRING "ggml: max input copies for pipeline parallelism")
|
||||
option(GGML_CPU "ggml: enable CPU backend" ON)
|
||||
option(GGML_SCHED_NO_REALLOC "ggml: disallow reallocations in ggml-alloc (for debugging)" OFF)
|
||||
|
||||
# 3rd party libs / backends
|
||||
option(GGML_ACCELERATE "ggml: enable Accelerate framework" ON)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ extern "C" {
|
|||
#endif
|
||||
|
||||
#define RPC_PROTO_MAJOR_VERSION 3
|
||||
#define RPC_PROTO_MINOR_VERSION 0
|
||||
#define RPC_PROTO_MINOR_VERSION 5
|
||||
#define RPC_PROTO_PATCH_VERSION 0
|
||||
#define GGML_RPC_MAX_SERVERS 16
|
||||
|
||||
|
|
|
|||
|
|
@ -221,6 +221,10 @@ if (GGML_BACKEND_DL)
|
|||
target_compile_definitions(ggml-base PUBLIC GGML_BACKEND_DL)
|
||||
endif()
|
||||
|
||||
if (GGML_SCHED_NO_REALLOC)
|
||||
target_compile_definitions(ggml-base PUBLIC GGML_SCHED_NO_REALLOC)
|
||||
endif()
|
||||
|
||||
add_library(ggml
|
||||
ggml-backend-reg.cpp)
|
||||
add_library(ggml::ggml ALIAS ggml)
|
||||
|
|
@ -270,10 +274,13 @@ function(ggml_add_backend_library backend)
|
|||
endif()
|
||||
|
||||
# Set versioning properties for all backend libraries
|
||||
set_target_properties(${backend} PROPERTIES
|
||||
VERSION ${GGML_VERSION}
|
||||
SOVERSION ${GGML_VERSION_MAJOR}
|
||||
)
|
||||
# Building a MODULE library with a version is not supported on macOS (https://gitlab.kitware.com/cmake/cmake/-/issues/20782)
|
||||
if (NOT (APPLE AND GGML_BACKEND_DL))
|
||||
set_target_properties(${backend} PROPERTIES
|
||||
VERSION ${GGML_VERSION}
|
||||
SOVERSION ${GGML_VERSION_MAJOR}
|
||||
)
|
||||
endif()
|
||||
|
||||
if(NOT GGML_AVAILABLE_BACKENDS)
|
||||
set(GGML_AVAILABLE_BACKENDS "${backend}"
|
||||
|
|
|
|||
|
|
@ -921,10 +921,15 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
|
|||
}
|
||||
if (realloc) {
|
||||
#ifndef NDEBUG
|
||||
size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0;
|
||||
GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
|
||||
{
|
||||
size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0;
|
||||
if (cur_size > 0) {
|
||||
GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n",
|
||||
__func__, ggml_backend_buft_name(galloc->bufts[i]),
|
||||
cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
ggml_vbuffer_free(galloc->buffers[i]);
|
||||
galloc->buffers[i] = ggml_vbuffer_alloc(galloc->bufts[i], galloc->buf_tallocs[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE);
|
||||
if (galloc->buffers[i] == NULL) {
|
||||
|
|
|
|||
|
|
@ -1395,14 +1395,20 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
|
|||
|
||||
// allocate graph
|
||||
if (backend_ids_changed || !ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {
|
||||
#ifdef GGML_SCHED_NO_REALLOC
|
||||
GGML_ABORT("%s: failed to allocate graph, but graph re-allocation is disabled by GGML_SCHED_NO_REALLOC\n", __func__);
|
||||
#endif
|
||||
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed);
|
||||
#endif
|
||||
|
||||
// the re-allocation may cause the split inputs to be moved to a different address
|
||||
// synchronize without ggml_backend_sched_synchronize to avoid changing cur_copy
|
||||
for (int i = 0; i < sched->n_backends; i++) {
|
||||
ggml_backend_synchronize(sched->backends[i]);
|
||||
}
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed);
|
||||
#endif
|
||||
|
||||
ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids);
|
||||
if (!ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {
|
||||
GGML_LOG_ERROR("%s: failed to allocate graph\n", __func__);
|
||||
|
|
|
|||
|
|
@ -1,20 +1,23 @@
|
|||
#include "ggml-backend-impl.h"
|
||||
|
||||
#if defined(__riscv) && __riscv_xlen == 64
|
||||
#include <sys/auxv.h>
|
||||
|
||||
//https://github.com/torvalds/linux/blob/master/arch/riscv/include/uapi/asm/hwcap.h#L24
|
||||
#ifndef COMPAT_HWCAP_ISA_V
|
||||
#define COMPAT_HWCAP_ISA_V (1 << ('V' - 'A'))
|
||||
#endif
|
||||
#include <asm/hwprobe.h>
|
||||
#include <asm/unistd.h>
|
||||
#include <unistd.h>
|
||||
|
||||
struct riscv64_features {
|
||||
bool has_rvv = false;
|
||||
|
||||
riscv64_features() {
|
||||
uint32_t hwcap = getauxval(AT_HWCAP);
|
||||
struct riscv_hwprobe probe;
|
||||
probe.key = RISCV_HWPROBE_KEY_IMA_EXT_0;
|
||||
probe.value = 0;
|
||||
|
||||
has_rvv = !!(hwcap & COMPAT_HWCAP_ISA_V);
|
||||
int ret = syscall(__NR_riscv_hwprobe, &probe, 1, 0, NULL, 0);
|
||||
|
||||
if (0 == ret) {
|
||||
has_rvv = !!(probe.value & RISCV_HWPROBE_IMA_V);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -9766,7 +9766,8 @@ static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params
|
|||
}
|
||||
|
||||
const float diag = A_batch[i00 * n + i00];
|
||||
GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix");
|
||||
assert(diag != 0.0f && "Zero diagonal in triangular matrix");
|
||||
|
||||
X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
|||
const dim3 offset_grid((nrows + block_size - 1) / block_size);
|
||||
init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);
|
||||
|
||||
cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream);
|
||||
CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream));
|
||||
|
||||
size_t temp_storage_bytes = 0;
|
||||
|
||||
|
|
|
|||
|
|
@ -21,10 +21,12 @@
|
|||
#include "ggml-common.h"
|
||||
|
||||
#include <array>
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cfloat>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#if defined(GGML_USE_HIP)
|
||||
|
|
@ -84,12 +86,12 @@
|
|||
|
||||
#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
|
||||
#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
|
||||
#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
|
||||
#define GGML_CUDA_CC_PH1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // MTT S5000
|
||||
|
||||
#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
|
||||
#define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
|
||||
#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG)
|
||||
#define GGML_CUDA_CC_IS_NG(cc) (cc >= GGML_CUDA_CC_NG)
|
||||
#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_PH1)
|
||||
#define GGML_CUDA_CC_IS_PH1(cc) (cc >= GGML_CUDA_CC_PH1)
|
||||
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
|
||||
# define GGML_CUDA_USE_CUB
|
||||
|
|
@ -212,9 +214,9 @@ static const char * cu_get_error_str(CUresult err) {
|
|||
#define GGML_USE_VMM
|
||||
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
|
||||
|
||||
#if defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
|
||||
#if defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
|
||||
#define FP16_AVAILABLE
|
||||
#endif // defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
|
||||
#endif // defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
|
||||
|
||||
#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
|
||||
#define FAST_FP16_AVAILABLE
|
||||
|
|
@ -250,12 +252,14 @@ static const char * cu_get_error_str(CUresult err) {
|
|||
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
|
||||
|
||||
static bool fp16_available(const int cc) {
|
||||
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
|
||||
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL ||
|
||||
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
|
||||
}
|
||||
|
||||
static bool fast_fp16_available(const int cc) {
|
||||
return GGML_CUDA_CC_IS_AMD(cc) ||
|
||||
(GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610);
|
||||
(GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610) ||
|
||||
(GGML_CUDA_CC_IS_MTHREADS(cc) && fp16_available(cc));
|
||||
}
|
||||
|
||||
// To be used for feature selection of external libraries, e.g. cuBLAS.
|
||||
|
|
@ -272,7 +276,9 @@ static bool fp16_mma_hardware_available(const int cc) {
|
|||
}
|
||||
|
||||
static bool bf16_mma_hardware_available(const int cc) {
|
||||
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) || GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
|
||||
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) ||
|
||||
GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3 ||
|
||||
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
|
||||
}
|
||||
|
||||
static bool fp32_mma_hardware_available(const int cc) {
|
||||
|
|
@ -558,8 +564,12 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v
|
|||
acc += v.y*u.y;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
|
||||
#if defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
|
||||
#define V_DOT2_F32_F16_AVAILABLE
|
||||
#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
|
||||
|
||||
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
|
||||
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||
asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u));
|
||||
#else
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
|
|
@ -571,7 +581,7 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v,
|
|||
acc += tmpv.x * tmpu.x;
|
||||
acc += tmpv.y * tmpu.y;
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(GCN5) || defined(CDNA))
|
||||
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, const half2 u) {
|
||||
|
|
@ -972,6 +982,154 @@ struct ggml_cuda_graph {
|
|||
#endif
|
||||
};
|
||||
|
||||
struct ggml_cuda_concurrent_event {
|
||||
std::vector<cudaEvent_t> join_events;
|
||||
cudaEvent_t fork_event = nullptr;
|
||||
|
||||
int n_streams = 0;
|
||||
std::unordered_map<const ggml_tensor *, int> stream_mapping;
|
||||
|
||||
const ggml_tensor * join_node;
|
||||
|
||||
ggml_cuda_concurrent_event() = default;
|
||||
|
||||
ggml_cuda_concurrent_event(const ggml_cuda_concurrent_event &) = delete;
|
||||
ggml_cuda_concurrent_event & operator=(const ggml_cuda_concurrent_event &) = delete;
|
||||
|
||||
explicit ggml_cuda_concurrent_event(int n_streams) : n_streams(n_streams) {
|
||||
join_events.resize(n_streams);
|
||||
|
||||
for (size_t i = 0; i < join_events.size(); ++i) {
|
||||
CUDA_CHECK(cudaEventCreateWithFlags(&join_events[i], cudaEventDisableTiming));
|
||||
}
|
||||
|
||||
CUDA_CHECK(cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming));
|
||||
}
|
||||
|
||||
ggml_cuda_concurrent_event(ggml_cuda_concurrent_event && other) noexcept
|
||||
: join_events(std::move(other.join_events))
|
||||
, fork_event(other.fork_event)
|
||||
, n_streams(other.n_streams)
|
||||
, stream_mapping(std::move(other.stream_mapping))
|
||||
, join_node(other.join_node) {
|
||||
other.fork_event = nullptr;
|
||||
}
|
||||
|
||||
// 1. check if any branches write to overlapping memory ranges (except the join node)
|
||||
// 2. check whether all srcs are either within the branch or outside the nodes covered by ggml_cuda_concurrent_event
|
||||
// we assume all nodes have the same buffer
|
||||
bool is_valid() const {
|
||||
std::vector<std::vector<std::pair<int64_t, int64_t>>> write_ranges;
|
||||
write_ranges.resize(n_streams);
|
||||
|
||||
// get join_node's memory range to exclude from overlap checking.
|
||||
// multiple nodes can use join_node's buffer; we synchronize on the join node.
|
||||
const ggml_tensor * join_t = join_node->view_src ? join_node->view_src : join_node;
|
||||
const int64_t join_start = (int64_t) join_t->data;
|
||||
const int64_t join_end = join_start + ggml_nbytes(join_t);
|
||||
|
||||
for (const auto & [tensor, stream] : stream_mapping) {
|
||||
const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
|
||||
const int64_t t_start = (int64_t) t->data;
|
||||
const int64_t t_end = t_start + ggml_nbytes(t);
|
||||
|
||||
// skip tensors that overlap with join_node's buffer.
|
||||
if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// concurrent streams begin from 1
|
||||
write_ranges[stream - 1].emplace_back(t_start, t_end);
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_streams; ++i) {
|
||||
// sorts first by start then by end of write range
|
||||
std::sort(write_ranges[i].begin(), write_ranges[i].end());
|
||||
}
|
||||
|
||||
bool writes_overlap = false;
|
||||
bool dependent_srcs = false;
|
||||
for (const auto & [tensor, stream] : stream_mapping) {
|
||||
const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
|
||||
const int64_t t_start = (int64_t) t->data;
|
||||
const int64_t t_end = t_start + ggml_nbytes(t);
|
||||
|
||||
// skip tensors that overlap with join_node's buffer
|
||||
if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// check if this buffer's write data overlaps with another stream's
|
||||
std::pair<int64_t, int64_t> data_range = std::make_pair(t_start, t_end);
|
||||
for (int i = 0; i < n_streams; ++i) {
|
||||
if (i == stream - 1) {
|
||||
continue;
|
||||
}
|
||||
auto it = std::lower_bound(write_ranges[i].begin(), write_ranges[i].end(), data_range);
|
||||
|
||||
if (it != write_ranges[i].end()) {
|
||||
const std::pair<int64_t, int64_t> & other = *it;
|
||||
|
||||
// std::lower_bound returns the first element where other >= data_range (lexicographically).
|
||||
// This guarantees other.first >= data_range.first.
|
||||
// Therefore, overlap occurs iff other.first < data_range.second
|
||||
// (i.e., the other range starts before this range ends).
|
||||
if (other.first < data_range.second) {
|
||||
GGML_LOG_DEBUG("Writes overlap for %s", tensor->name);
|
||||
writes_overlap = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//check if all srcs are either in branch or don't have a branch
|
||||
for (int i = 0; i < GGML_MAX_SRC; ++i) {
|
||||
if (!tensor->src[i]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto it = stream_mapping.find(tensor->src[i]);
|
||||
|
||||
if (it == stream_mapping.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (it->second != stream) {
|
||||
dependent_srcs = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (dependent_srcs || writes_overlap) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return !writes_overlap && !dependent_srcs;
|
||||
}
|
||||
|
||||
~ggml_cuda_concurrent_event() {
|
||||
if (fork_event != nullptr) {
|
||||
CUDA_CHECK(cudaEventDestroy(fork_event));
|
||||
}
|
||||
for (cudaEvent_t e : join_events) {
|
||||
if (e != nullptr) {
|
||||
CUDA_CHECK(cudaEventDestroy(e));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_cuda_stream_context {
|
||||
std::vector<const ggml_tensor *> original_nodes;
|
||||
std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> concurrent_events;
|
||||
|
||||
void reset() {
|
||||
original_nodes.clear();
|
||||
concurrent_events.clear();
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_backend_cuda_context {
|
||||
int device;
|
||||
std::string name;
|
||||
|
|
@ -982,11 +1140,15 @@ struct ggml_backend_cuda_context {
|
|||
|
||||
std::unique_ptr<ggml_cuda_graph> cuda_graph;
|
||||
|
||||
int curr_stream_no = 0;
|
||||
|
||||
explicit ggml_backend_cuda_context(int device) :
|
||||
device(device),
|
||||
name(GGML_CUDA_NAME + std::to_string(device)) {
|
||||
}
|
||||
|
||||
ggml_cuda_stream_context concurrent_stream_context;
|
||||
|
||||
~ggml_backend_cuda_context();
|
||||
|
||||
cudaStream_t stream(int device, int stream) {
|
||||
|
|
@ -997,9 +1159,9 @@ struct ggml_backend_cuda_context {
|
|||
return streams[device][stream];
|
||||
}
|
||||
|
||||
cudaStream_t stream() {
|
||||
return stream(device, 0);
|
||||
}
|
||||
cudaStream_t stream() { return stream(device, curr_stream_no); }
|
||||
|
||||
ggml_cuda_stream_context & stream_context() { return concurrent_stream_context; }
|
||||
|
||||
cublasHandle_t cublas_handle(int device) {
|
||||
if (cublas_handles[device] == nullptr) {
|
||||
|
|
@ -1015,15 +1177,15 @@ struct ggml_backend_cuda_context {
|
|||
}
|
||||
|
||||
// pool
|
||||
std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];
|
||||
std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS];
|
||||
|
||||
static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device);
|
||||
static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device, int stream_no);
|
||||
|
||||
ggml_cuda_pool & pool(int device) {
|
||||
if (pools[device] == nullptr) {
|
||||
pools[device] = new_pool_for_device(device);
|
||||
if (pools[device][curr_stream_no] == nullptr) {
|
||||
pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no);
|
||||
}
|
||||
return *pools[device];
|
||||
return *pools[device][curr_stream_no];
|
||||
}
|
||||
|
||||
ggml_cuda_pool & pool() {
|
||||
|
|
|
|||
|
|
@ -86,6 +86,9 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
GGML_UNUSED_VARS(ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11,
|
||||
nb12, nb13);
|
||||
}
|
||||
|
||||
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
|
||||
|
|
@ -202,7 +205,7 @@ static void ggml_cpy_scalar_cuda(
|
|||
ne00n = ne00;
|
||||
ne01n = ne01;
|
||||
ne02n = ne02;
|
||||
} else if (nb00 > nb02) {
|
||||
} else {
|
||||
ne00n = ne00;
|
||||
ne01n = ne01*ne02;
|
||||
ne02n = 1;
|
||||
|
|
|
|||
|
|
@ -55,11 +55,11 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
|
|||
ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||
ggml_cuda_mad(sum, tmp[k_KQ_1] , ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
|
||||
#else
|
||||
ggml_cuda_mad(sum, __half22float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
|
||||
#endif // FP16_AVAILABLE
|
||||
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -609,7 +609,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
|
|||
float KQ_sum_add = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
|
||||
const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ?
|
||||
const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < static_cast<uint32_t>(k_VKQ_sup) ?
|
||||
expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f;
|
||||
KQ_sum_add += val;
|
||||
tmp[i0/(np*warp_size)][jc1] = val;
|
||||
|
|
|
|||
|
|
@ -86,11 +86,11 @@ static __global__ void flash_attn_ext_vec(
|
|||
|
||||
constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
|
||||
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>();
|
||||
#else
|
||||
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, float, V_rows_per_thread>();
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||
|
||||
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
||||
|
||||
|
|
@ -112,13 +112,13 @@ static __global__ void flash_attn_ext_vec(
|
|||
|
||||
constexpr int ne_KQ = ncols*D;
|
||||
constexpr int ne_combine = nwarps*V_cols_per_iter*D;
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||
half2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
|
||||
__shared__ half KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
|
||||
#else
|
||||
float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
|
||||
__shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||
|
||||
float KQ_max[ncols];
|
||||
float KQ_sum[ncols];
|
||||
|
|
@ -129,11 +129,11 @@ static __global__ void flash_attn_ext_vec(
|
|||
}
|
||||
|
||||
// Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||
half2 Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely.
|
||||
#else
|
||||
float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||
int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
|
||||
float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
|
||||
if constexpr (Q_q8_1) {
|
||||
|
|
@ -155,7 +155,7 @@ static __global__ void flash_attn_ext_vec(
|
|||
for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
if (i0 + WARP_SIZE <= D/sizeof(int) || i < D/sizeof(int)) {
|
||||
if (i0 + WARP_SIZE <= int(D/sizeof(int)) || i < int(D/sizeof(int))) {
|
||||
tmp_q_i32[i] = 0;
|
||||
}
|
||||
}
|
||||
|
|
@ -191,7 +191,7 @@ static __global__ void flash_attn_ext_vec(
|
|||
|
||||
__syncthreads();
|
||||
} else {
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||
const half2 scale_h2 = make_half2(scale, scale);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
|
|
@ -233,7 +233,7 @@ static __global__ void flash_attn_ext_vec(
|
|||
Q_reg[j][k].y *= scale;
|
||||
}
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||
}
|
||||
|
||||
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
|
||||
|
|
@ -272,7 +272,7 @@ static __global__ void flash_attn_ext_vec(
|
|||
|
||||
KQ_max_new[j] = fmaxf(KQ_max_new[j], sum);
|
||||
|
||||
if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == i_KQ_0) {
|
||||
if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == uint32_t(i_KQ_0)) {
|
||||
KQ_reg[j] = sum;
|
||||
}
|
||||
}
|
||||
|
|
@ -291,7 +291,7 @@ static __global__ void flash_attn_ext_vec(
|
|||
KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];
|
||||
KQ[j*nthreads + tid] = KQ_reg[j];
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
|
||||
|
|
@ -303,7 +303,7 @@ static __global__ void flash_attn_ext_vec(
|
|||
VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
|
||||
VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||
}
|
||||
|
||||
#ifndef GGML_USE_HIP
|
||||
|
|
@ -314,7 +314,7 @@ static __global__ void flash_attn_ext_vec(
|
|||
for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) {
|
||||
const int k = threadIdx.y*WARP_SIZE + k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V);
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||
half2 KQ_k[ncols];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
|
|
@ -353,7 +353,7 @@ static __global__ void flash_attn_ext_vec(
|
|||
}
|
||||
}
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -374,7 +374,7 @@ static __global__ void flash_attn_ext_vec(
|
|||
|
||||
KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? expf(sink - KQ_max[j]) : 0.0f);
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
|
||||
|
|
@ -386,7 +386,7 @@ static __global__ void flash_attn_ext_vec(
|
|||
VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
|
||||
VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -421,7 +421,7 @@ static __global__ void flash_attn_ext_vec(
|
|||
const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new);
|
||||
KQ_max[j_VKQ] = kqmax_new;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||
half2 * VKQ_tmp = (half2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2)
|
||||
+ (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2);
|
||||
|
||||
|
|
@ -452,7 +452,7 @@ static __global__ void flash_attn_ext_vec(
|
|||
ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
|
||||
ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]);
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
#endif // V_DOT2_F32_F16_AVAILABLE
|
||||
|
||||
KQ_sum[j_VKQ] *= kqmax_scale;
|
||||
KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);
|
||||
|
|
|
|||
|
|
@ -522,7 +522,8 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
|||
};
|
||||
#endif // defined(GGML_USE_VMM)
|
||||
|
||||
std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
|
||||
std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device,
|
||||
[[maybe_unused]] int stream_no) {
|
||||
#if defined(GGML_USE_VMM)
|
||||
if (ggml_cuda_info().devices[device].vmm) {
|
||||
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
|
||||
|
|
@ -3050,7 +3051,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
|||
std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
|
||||
ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
|
||||
|
||||
if (ops.size() == topk_moe_ops_with_norm.size() &&
|
||||
const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,
|
||||
const std::initializer_list<enum ggml_op> & list2) {
|
||||
return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
|
||||
};
|
||||
|
||||
if (is_equal(topk_moe_ops_with_norm, ops) &&
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
|
||||
ggml_tensor * softmax = cgraph->nodes[node_idx];
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx + 9];
|
||||
|
|
@ -3060,8 +3066,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
|||
}
|
||||
}
|
||||
|
||||
if (ops.size() == topk_moe_ops.size() &&
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
|
||||
if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
|
||||
ggml_tensor * softmax = cgraph->nodes[node_idx];
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx + 4];
|
||||
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
|
||||
|
|
@ -3069,7 +3074,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
|||
}
|
||||
}
|
||||
|
||||
if (ops.size() == topk_moe_ops_delayed_softmax.size() &&
|
||||
if (is_equal(topk_moe_ops_delayed_softmax, ops) &&
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
|
||||
ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx + 5];
|
||||
|
|
@ -3085,9 +3090,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
|||
std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };
|
||||
std::initializer_list<enum ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };
|
||||
|
||||
if (ops.size() == 5 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}) ||
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}))) {
|
||||
|
||||
if ((is_equal(mul_mat_bias_glu_ops, ops) || is_equal(mul_mat_id_bias_glu_ops, ops)) &&
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 4 })) {
|
||||
const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
|
||||
const ggml_tensor * ffn_gate_bias = cgraph->nodes[node_idx + 1];
|
||||
const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 2];
|
||||
|
|
@ -3099,9 +3103,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
|||
}
|
||||
}
|
||||
|
||||
if (ops.size() == 3 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}) ||
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}))) {
|
||||
|
||||
if ((is_equal(mul_mat_id_glu_ops, ops) || is_equal(mul_mat_glu_ops, ops)) &&
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
|
||||
const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
|
||||
const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 1];
|
||||
const ggml_tensor * glu = cgraph->nodes[node_idx + 2];
|
||||
|
|
@ -3111,7 +3114,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
|||
}
|
||||
}
|
||||
|
||||
if (ops.size() == 3 && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
|
||||
std::initializer_list<enum ggml_op> rope_set_rows_ops = { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS };
|
||||
|
||||
if (is_equal(rope_set_rows_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
|
||||
const ggml_tensor * rope = cgraph->nodes[node_idx];
|
||||
const ggml_tensor * view = cgraph->nodes[node_idx + 1];
|
||||
const ggml_tensor * set_rows = cgraph->nodes[node_idx + 2];
|
||||
|
|
@ -3196,18 +3201,83 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
|||
// flag used to determine whether it is an integrated_gpu
|
||||
const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
|
||||
|
||||
ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context();
|
||||
bool is_concurrent_event_active = false;
|
||||
ggml_cuda_concurrent_event * concurrent_event = nullptr;
|
||||
bool should_launch_concurrent_events = false;
|
||||
|
||||
const auto try_launch_concurrent_event = [&](const ggml_tensor * node) {
|
||||
if (stream_ctx.concurrent_events.find(node) != stream_ctx.concurrent_events.end()) {
|
||||
concurrent_event = &stream_ctx.concurrent_events[node];
|
||||
|
||||
is_concurrent_event_active = true;
|
||||
|
||||
GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name);
|
||||
|
||||
cudaStream_t main_stream = cuda_ctx->stream(); // this should be stream 0
|
||||
GGML_ASSERT(cuda_ctx->curr_stream_no == 0);
|
||||
CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream));
|
||||
|
||||
for (int i = 1; i <= concurrent_event->n_streams; ++i) {
|
||||
cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i);
|
||||
CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
while (!graph_evaluated_or_captured) {
|
||||
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
|
||||
// With the use of CUDA graphs, the execution will be performed by the graph launch.
|
||||
if (!use_cuda_graph || cuda_graph_update_required) {
|
||||
|
||||
[[maybe_unused]] int prev_i = 0;
|
||||
|
||||
if (stream_ctx.concurrent_events.size() > 0) {
|
||||
should_launch_concurrent_events = true;
|
||||
for (const auto & [tensor, event] : stream_ctx.concurrent_events) {
|
||||
should_launch_concurrent_events = should_launch_concurrent_events && event.is_valid();
|
||||
}
|
||||
}
|
||||
if (should_launch_concurrent_events) {
|
||||
//Restore the original graph to enable fusion within the streams
|
||||
cgraph->nodes = const_cast<ggml_tensor **>(stream_ctx.original_nodes.data());
|
||||
cgraph->n_nodes = (int) stream_ctx.original_nodes.size();
|
||||
}
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
if (is_concurrent_event_active) {
|
||||
GGML_ASSERT(concurrent_event);
|
||||
|
||||
if (node == concurrent_event->join_node) {
|
||||
cuda_ctx->curr_stream_no = 0;
|
||||
for (int i = 1; i <= concurrent_event->n_streams; ++i) {
|
||||
// Wait on join events of forked streams in the main stream
|
||||
CUDA_CHECK(cudaEventRecord(concurrent_event->join_events[i - 1],
|
||||
cuda_ctx->stream(cuda_ctx->device, i)));
|
||||
CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), concurrent_event->join_events[i - 1]));
|
||||
}
|
||||
|
||||
is_concurrent_event_active = false;
|
||||
concurrent_event = nullptr;
|
||||
} else {
|
||||
GGML_ASSERT (concurrent_event->stream_mapping.find(node) != concurrent_event->stream_mapping.end());
|
||||
cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node];
|
||||
GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name);
|
||||
}
|
||||
} else if (i - prev_i > 1) {
|
||||
//the previous node was fused
|
||||
const ggml_tensor * prev_node = cgraph->nodes[i - 1];
|
||||
try_launch_concurrent_event(prev_node);
|
||||
|
||||
if (is_concurrent_event_active) {
|
||||
cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node];
|
||||
GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name);
|
||||
}
|
||||
}
|
||||
prev_i = i;
|
||||
|
||||
#ifdef GGML_CUDA_DEBUG
|
||||
const int nodes_fused = i - prev_i - 1;
|
||||
prev_i = i;
|
||||
if (nodes_fused > 0) {
|
||||
GGML_LOG_INFO("nodes_fused: %d\n", nodes_fused);
|
||||
}
|
||||
|
|
@ -3217,6 +3287,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
|||
continue;
|
||||
}
|
||||
|
||||
|
||||
// start of fusion operations
|
||||
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
|
||||
if (!disable_fusion) {
|
||||
|
||||
|
|
@ -3509,13 +3581,17 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
|||
}
|
||||
#else
|
||||
GGML_UNUSED(integrated);
|
||||
#endif // NDEBUG
|
||||
#endif // NDEBUG
|
||||
|
||||
bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
|
||||
if (!ok) {
|
||||
GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
|
||||
}
|
||||
GGML_ASSERT(ok);
|
||||
|
||||
if (!is_concurrent_event_active) {
|
||||
try_launch_concurrent_event(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -3655,6 +3731,235 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
|
|||
}
|
||||
}
|
||||
|
||||
static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
|
||||
|
||||
static bool enable_graph_optimization = [] {
|
||||
const char * env = getenv("GGML_CUDA_GRAPH_OPT");
|
||||
return env != nullptr && atoi(env) == 1;
|
||||
}();
|
||||
|
||||
if (!enable_graph_optimization) {
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(ggml_backend_cuda_get_device_count() == 1 && "compute graph optimization is only supported on single GPU in the CUDA backend");
|
||||
GGML_LOG_DEBUG("Optimizing CUDA graph %p with %d nodes\n", cgraph->nodes, cgraph->n_nodes);
|
||||
|
||||
ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context();
|
||||
stream_context.reset();
|
||||
|
||||
// number of out-degrees for a particular node
|
||||
std::unordered_map<const ggml_tensor *, int> fan_out;
|
||||
// reverse mapping of node to index in the cgraph
|
||||
std::unordered_map<const ggml_tensor *, int> node_indices;
|
||||
|
||||
const auto & is_noop = [](const ggml_tensor * node) -> bool {
|
||||
return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE ||
|
||||
node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
|
||||
};
|
||||
|
||||
const auto & depends_on = [](const ggml_tensor * dst, const ggml_tensor * src) -> bool {
|
||||
for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {
|
||||
if (dst->src[s] == src) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
// implicit dependency if they view the same tensor
|
||||
const ggml_tensor * dst2 = dst->view_src ? dst->view_src : dst;
|
||||
const ggml_tensor * src2 = src->view_src ? src->view_src : src;
|
||||
if (dst2 == src2) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) {
|
||||
const ggml_tensor * node = cgraph->nodes[node_idx];
|
||||
node_indices[node] = node_idx;
|
||||
|
||||
if (is_noop(node)) {
|
||||
continue;
|
||||
}
|
||||
for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
|
||||
const ggml_tensor * src = cgraph->nodes[node_idx]->src[src_idx];
|
||||
//TODO: check why nrows > 1 fails
|
||||
if (node && !is_noop(node) && ggml_nrows(node) <= 1) {
|
||||
fan_out[src] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Target Q, K, V for concurrency
|
||||
// this is a more general way to find nodes which can be candidates for concurrency (although it has not been tested for anything else):
|
||||
// 1. find fan-out (fork) nodes where the same input is used at least N times (in QKV, it would be "attn-norm")
|
||||
// 2. find the join node, where 2 or more of the outputs are required (in QKV, this would "KQ" or "flash-attn")
|
||||
// 3. account for all branches from the fork to the join
|
||||
// 4. To extend lifetimes of the tensors, we interleave the branches (see below for more details)
|
||||
// 5. save the original cgraph and restore it in graph_compute, to enable fusion within streams
|
||||
// See discussion: https://github.com/ggml-org/llama.cpp/pull/16991#issuecomment-3522620030
|
||||
|
||||
const int min_fan_out = 3;
|
||||
const int max_fan_out = 3;
|
||||
|
||||
// store {fork_idx, join_idx}
|
||||
std::vector<std::pair<int, int>> concurrent_node_ranges;
|
||||
|
||||
// save the original nodes
|
||||
std::vector<const ggml_tensor *> original_nodes;
|
||||
original_nodes.reserve(cgraph->n_nodes);
|
||||
for (int i = 0; i < cgraph->n_nodes; ++i) {
|
||||
original_nodes.push_back(cgraph->nodes[i]);
|
||||
}
|
||||
cuda_ctx->stream_context().original_nodes = std::move(original_nodes);
|
||||
|
||||
for (const auto & [root_node, count] : fan_out) {
|
||||
if (count >= min_fan_out && count <= max_fan_out) {
|
||||
const int root_node_idx = node_indices[root_node];
|
||||
|
||||
bool is_part_of_event = false;
|
||||
for (const auto & [start, end] : concurrent_node_ranges) {
|
||||
if (root_node_idx >= start && root_node_idx <= end) {
|
||||
is_part_of_event = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_part_of_event) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<std::vector<const ggml_tensor *>> nodes_per_branch;
|
||||
for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) {
|
||||
const ggml_tensor * node = cgraph->nodes[i];
|
||||
if (!is_noop(node) && depends_on(node, root_node)) {
|
||||
nodes_per_branch.push_back({ node });
|
||||
}
|
||||
}
|
||||
|
||||
GGML_ASSERT(nodes_per_branch.size() == (size_t) count);
|
||||
|
||||
//find the join point
|
||||
const ggml_tensor * join_node = nullptr;
|
||||
|
||||
const auto & belongs_to_branch = [&](const ggml_tensor * node,
|
||||
const std::vector<const ggml_tensor *> & branch) -> bool {
|
||||
for (const ggml_tensor * n : branch) {
|
||||
if (depends_on(node, n)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) {
|
||||
const ggml_tensor * curr_node = cgraph->nodes[i];
|
||||
|
||||
int num_joins = 0;
|
||||
for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
|
||||
if (belongs_to_branch(curr_node, nodes_per_branch[branch_idx])) {
|
||||
num_joins++;
|
||||
}
|
||||
}
|
||||
|
||||
if (num_joins >= 2) {
|
||||
join_node = curr_node;
|
||||
break;
|
||||
}
|
||||
|
||||
bool found_branch = false;
|
||||
for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
|
||||
std::vector<const ggml_tensor *> & branch_vec = nodes_per_branch[branch_idx];
|
||||
if (belongs_to_branch(curr_node, branch_vec)) {
|
||||
//continue accumulating
|
||||
if (std::find(branch_vec.begin(), branch_vec.end(), curr_node) == branch_vec.end()) {
|
||||
branch_vec.push_back(curr_node);
|
||||
}
|
||||
found_branch = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!found_branch && is_noop(curr_node)) {
|
||||
// we can put it in any branch because it will be ignored
|
||||
nodes_per_branch[0].push_back({ curr_node });
|
||||
}
|
||||
}
|
||||
|
||||
if (join_node) {
|
||||
//Create ggml_cuda_concurrent_event
|
||||
ggml_cuda_concurrent_event concurrent_event(nodes_per_branch.size());
|
||||
concurrent_event.join_node = join_node;
|
||||
|
||||
for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
|
||||
for (const ggml_tensor * n : nodes_per_branch[branch_idx]) {
|
||||
concurrent_event.stream_mapping[n] = branch_idx + 1;
|
||||
}
|
||||
}
|
||||
|
||||
int fork_node_idx = node_indices[root_node];
|
||||
int join_node_idx = node_indices[join_node];
|
||||
|
||||
int current_branch_idx = 0;
|
||||
int current_node_idx = fork_node_idx + 1;
|
||||
const int n_branches = nodes_per_branch.size();
|
||||
|
||||
int total_branch_nodes = 0;
|
||||
for (std::vector<const ggml_tensor *> branch_nodes : nodes_per_branch) {
|
||||
total_branch_nodes += branch_nodes.size();
|
||||
}
|
||||
|
||||
// there are other nodes in the middle which are unaccounted for
|
||||
// usually (cpy) nodes, then ignore this fork
|
||||
if (join_node_idx - fork_node_idx - 1 != total_branch_nodes) {
|
||||
GGML_LOG_DEBUG(
|
||||
"Skipping %s because the number of nodes in the middle is not equal to the total number of "
|
||||
"branch nodes %d != %d\n",
|
||||
root_node->name, join_node_idx - fork_node_idx - 1, total_branch_nodes);
|
||||
continue;
|
||||
}
|
||||
|
||||
std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> & concurrent_events = cuda_ctx->stream_context().concurrent_events;
|
||||
GGML_ASSERT(concurrent_events.find(root_node) == concurrent_events.end());
|
||||
concurrent_events.emplace(root_node, std::move(concurrent_event));
|
||||
GGML_LOG_DEBUG("Adding stream at node %s %p\n", root_node->name, root_node);
|
||||
concurrent_node_ranges.emplace_back(fork_node_idx, join_node_idx);
|
||||
|
||||
// interleave tensors to extend lifetimes so that ggml graph doesn't recycle them
|
||||
// example transformation:
|
||||
// [attn-norm, QMul, QNorm, QRope, KMul, KNorm, KRope, VMul, attn] ->
|
||||
// [attn-norm, QMul, KMul, VMul, QNorm, VNorm, QRope, KRope, attn]
|
||||
while (current_node_idx < join_node_idx) {
|
||||
std::vector<const ggml_tensor *> & branch_nodes = nodes_per_branch[current_branch_idx];
|
||||
|
||||
bool has_node = false;
|
||||
for (std::vector<const ggml_tensor *> branch_node : nodes_per_branch) {
|
||||
has_node |= branch_node.size() > 0;
|
||||
}
|
||||
|
||||
GGML_ASSERT(has_node);
|
||||
|
||||
if (branch_nodes.empty()) {
|
||||
current_branch_idx = (current_branch_idx + 1) % n_branches;
|
||||
continue;
|
||||
}
|
||||
|
||||
cgraph->nodes[current_node_idx] = const_cast<ggml_tensor *>(branch_nodes.front());
|
||||
current_node_idx++;
|
||||
branch_nodes.erase(branch_nodes.begin());
|
||||
|
||||
// append all empty nodes
|
||||
while (!branch_nodes.empty() && is_noop(branch_nodes.front())) {
|
||||
cgraph->nodes[current_node_idx] = const_cast<ggml_tensor *>(branch_nodes.front());
|
||||
current_node_idx++;
|
||||
branch_nodes.erase(branch_nodes.begin());
|
||||
}
|
||||
|
||||
current_branch_idx = (current_branch_idx + 1) % n_branches;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static const ggml_backend_i ggml_backend_cuda_interface = {
|
||||
/* .get_name = */ ggml_backend_cuda_get_name,
|
||||
/* .free = */ ggml_backend_cuda_free,
|
||||
|
|
@ -3669,7 +3974,7 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
|
|||
/* .graph_compute = */ ggml_backend_cuda_graph_compute,
|
||||
/* .event_record = */ ggml_backend_cuda_event_record,
|
||||
/* .event_wait = */ ggml_backend_cuda_event_wait,
|
||||
/* .graph_optimize = */ NULL,
|
||||
/* .graph_optimize = */ ggml_backend_cuda_graph_optimize,
|
||||
};
|
||||
|
||||
static ggml_guid_t ggml_backend_cuda_guid() {
|
||||
|
|
|
|||
|
|
@ -889,8 +889,8 @@ namespace ggml_cuda_mma {
|
|||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
||||
: "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7]));
|
||||
#else
|
||||
tile<16, 8, float> * D16 = (tile<16, 8, float> *) &D;
|
||||
tile<16, 8, half2> * A16 = (tile<16, 8, half2> *) &A;
|
||||
tile <16, 8, float> * D16 = reinterpret_cast<tile <16, 8, float> *>(&D);
|
||||
const tile<16, 8, half2> * A16 = reinterpret_cast<const tile<16, 8, half2> *>(&A);
|
||||
mma(D16[0], A16[0], B);
|
||||
mma(D16[1], A16[1], B);
|
||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
|
|
|
|||
|
|
@ -151,7 +151,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
|
|||
return false;
|
||||
}
|
||||
} else {
|
||||
if (src1_ncols > 16 || GGML_CUDA_CC_IS_RDNA4(cc)) {
|
||||
if (src1_ncols > 16) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -105,7 +105,7 @@
|
|||
#define cudaStreamNonBlocking hipStreamNonBlocking
|
||||
#define cudaStreamPerThread hipStreamPerThread
|
||||
#define cudaStreamSynchronize hipStreamSynchronize
|
||||
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
|
||||
#define cudaStreamWaitEvent hipStreamWaitEvent
|
||||
#define cudaGraphExec_t hipGraphExec_t
|
||||
#define cudaGraphNode_t hipGraphNode_t
|
||||
#define cudaKernelNodeParams hipKernelNodeParams
|
||||
|
|
|
|||
|
|
@ -106,6 +106,7 @@ enum rpc_cmd {
|
|||
RPC_CMD_GET_ALLOC_SIZE,
|
||||
RPC_CMD_HELLO,
|
||||
RPC_CMD_DEVICE_COUNT,
|
||||
RPC_CMD_GRAPH_RECOMPUTE,
|
||||
RPC_CMD_COUNT,
|
||||
};
|
||||
|
||||
|
|
@ -205,10 +206,6 @@ struct rpc_msg_copy_tensor_rsp {
|
|||
uint8_t result;
|
||||
};
|
||||
|
||||
struct rpc_msg_graph_compute_rsp {
|
||||
uint8_t result;
|
||||
};
|
||||
|
||||
struct rpc_msg_get_device_memory_req {
|
||||
uint32_t device;
|
||||
};
|
||||
|
|
@ -217,6 +214,11 @@ struct rpc_msg_get_device_memory_rsp {
|
|||
uint64_t free_mem;
|
||||
uint64_t total_mem;
|
||||
};
|
||||
|
||||
struct rpc_msg_graph_recompute_req {
|
||||
uint32_t device;
|
||||
};
|
||||
|
||||
#pragma pack(pop)
|
||||
|
||||
// RPC data structures
|
||||
|
|
@ -234,10 +236,35 @@ struct ggml_backend_rpc_buffer_type_context {
|
|||
size_t max_size;
|
||||
};
|
||||
|
||||
struct graph_cache {
|
||||
|
||||
bool is_cached(const ggml_cgraph * cgraph) {
|
||||
if ((int)last_graph.size() != cgraph->n_nodes) {
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
if (memcmp(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void add(const ggml_cgraph * cgraph) {
|
||||
last_graph.resize(cgraph->n_nodes);
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
memcpy(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ggml_tensor> last_graph;
|
||||
};
|
||||
|
||||
struct ggml_backend_rpc_context {
|
||||
std::string endpoint;
|
||||
uint32_t device;
|
||||
std::string name;
|
||||
graph_cache gc;
|
||||
};
|
||||
|
||||
struct ggml_backend_rpc_buffer_context {
|
||||
|
|
@ -815,13 +842,24 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve
|
|||
|
||||
static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
||||
std::vector<uint8_t> input;
|
||||
serialize_graph(rpc_ctx->device, cgraph, input);
|
||||
rpc_msg_graph_compute_rsp response;
|
||||
auto sock = get_socket(rpc_ctx->endpoint);
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
|
||||
RPC_STATUS_ASSERT(status);
|
||||
return (enum ggml_status)response.result;
|
||||
|
||||
GGML_ASSERT(cgraph->n_nodes > 0);
|
||||
bool reuse = rpc_ctx->gc.is_cached(cgraph);
|
||||
if (reuse) {
|
||||
rpc_msg_graph_recompute_req request;
|
||||
request.device = rpc_ctx->device;
|
||||
auto sock = get_socket(rpc_ctx->endpoint);
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request));
|
||||
RPC_STATUS_ASSERT(status);
|
||||
} else {
|
||||
rpc_ctx->gc.add(cgraph);
|
||||
std::vector<uint8_t> input;
|
||||
serialize_graph(rpc_ctx->device, cgraph, input);
|
||||
auto sock = get_socket(rpc_ctx->endpoint);
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size());
|
||||
RPC_STATUS_ASSERT(status);
|
||||
}
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
static ggml_backend_i ggml_backend_rpc_interface = {
|
||||
|
|
@ -880,7 +918,8 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
|
|||
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
|
||||
/* .endpoint = */ endpoint,
|
||||
/* .device = */ device,
|
||||
/* .name = */ dev_name
|
||||
/* .name = */ dev_name,
|
||||
/* .gc = */ {},
|
||||
};
|
||||
auto reg = ggml_backend_rpc_add_server(endpoint);
|
||||
ggml_backend_t backend = new ggml_backend {
|
||||
|
|
@ -920,8 +959,9 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device,
|
|||
|
||||
class rpc_server {
|
||||
public:
|
||||
rpc_server(std::vector<ggml_backend_t> backends, const char * cache_dir)
|
||||
: backends(std::move(backends)), cache_dir(cache_dir) {
|
||||
rpc_server(std::vector<ggml_backend_t> all_backends, const char * cache_dir)
|
||||
: backends(std::move(all_backends)), cache_dir(cache_dir) {
|
||||
stored_graphs.resize(backends.size());
|
||||
}
|
||||
~rpc_server();
|
||||
|
||||
|
|
@ -936,11 +976,17 @@ public:
|
|||
bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
|
||||
bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
|
||||
bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
|
||||
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
|
||||
bool graph_compute(const std::vector<uint8_t> & input);
|
||||
bool graph_recompute(const rpc_msg_graph_recompute_req & request);
|
||||
bool init_tensor(const rpc_msg_init_tensor_req & request);
|
||||
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
|
||||
bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
|
||||
|
||||
struct stored_graph {
|
||||
ggml_context_ptr ctx_ptr;
|
||||
ggml_cgraph * graph;
|
||||
};
|
||||
|
||||
private:
|
||||
bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
|
||||
ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
|
||||
|
|
@ -953,6 +999,8 @@ private:
|
|||
std::vector<ggml_backend_t> backends;
|
||||
const char * cache_dir;
|
||||
std::unordered_set<ggml_backend_buffer_t> buffers;
|
||||
// store the last computed graph for each backend
|
||||
std::vector<stored_graph> stored_graphs;
|
||||
};
|
||||
|
||||
void rpc_server::hello(rpc_msg_hello_rsp & response) {
|
||||
|
|
@ -1394,7 +1442,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
|
|||
return result;
|
||||
}
|
||||
|
||||
bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
|
||||
bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
|
||||
// serialization format:
|
||||
// | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
|
||||
if (input.size() < 2*sizeof(uint32_t)) {
|
||||
|
|
@ -1455,7 +1503,24 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
|
|||
}
|
||||
}
|
||||
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
|
||||
response.result = status;
|
||||
GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
|
||||
stored_graphs[device].ctx_ptr.swap(ctx_ptr);
|
||||
stored_graphs[device].graph = graph;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request) {
|
||||
uint32_t device = request.device;
|
||||
if (device >= backends.size()) {
|
||||
return false;
|
||||
}
|
||||
if (stored_graphs[device].graph == nullptr) {
|
||||
return false;
|
||||
}
|
||||
ggml_cgraph * graph = stored_graphs[device].graph;
|
||||
LOG_DBG("[%s] device: %u\n", __func__, device);
|
||||
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
|
||||
GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
@ -1690,11 +1755,17 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
|
|||
if (!recv_msg(sockfd, input)) {
|
||||
return;
|
||||
}
|
||||
rpc_msg_graph_compute_rsp response;
|
||||
if (!server.graph_compute(input, response)) {
|
||||
if (!server.graph_compute(input)) {
|
||||
return;
|
||||
}
|
||||
if (!send_msg(sockfd, &response, sizeof(response))) {
|
||||
break;
|
||||
}
|
||||
case RPC_CMD_GRAPH_RECOMPUTE: {
|
||||
rpc_msg_graph_recompute_req request;
|
||||
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
||||
return;
|
||||
}
|
||||
if (!server.graph_recompute(request)) {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
|
|
|
|||
|
|
@ -91,7 +91,10 @@ if (GGML_SYCL_F16)
|
|||
add_compile_definitions(GGML_SYCL_F16)
|
||||
endif()
|
||||
|
||||
if (GGML_SYCL_TARGET STREQUAL "NVIDIA")
|
||||
if (GGML_SYCL_TARGET STREQUAL "INTEL")
|
||||
add_compile_definitions(GGML_SYCL_WARP_SIZE=16)
|
||||
target_link_options(ggml-sycl PRIVATE -Xs -ze-intel-greater-than-4GB-buffer-required)
|
||||
elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA")
|
||||
add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
|
||||
elseif (GGML_SYCL_TARGET STREQUAL "AMD")
|
||||
# INFO: Allowed Sub_group_sizes are not consistent through all
|
||||
|
|
@ -100,7 +103,8 @@ elseif (GGML_SYCL_TARGET STREQUAL "AMD")
|
|||
# Target archs tested working: gfx1030, gfx1031, (Only tested sub_group_size = 32)
|
||||
add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
|
||||
else()
|
||||
add_compile_definitions(GGML_SYCL_WARP_SIZE=16)
|
||||
# default for other target
|
||||
add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
|
||||
endif()
|
||||
|
||||
if (GGML_SYCL_GRAPH)
|
||||
|
|
|
|||
|
|
@ -515,9 +515,6 @@ void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, co
|
|||
const int64_t ne = ggml_nelements(src0);
|
||||
GGML_ASSERT(ne == ggml_nelements(src1));
|
||||
|
||||
GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
|
||||
GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS01;
|
||||
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
|
|
|
|||
|
|
@ -613,9 +613,10 @@ struct vk_device_struct {
|
|||
vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
|
||||
vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
|
||||
vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT];
|
||||
|
||||
vk_pipeline pipeline_dequant_mul_mat_vec_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
|
||||
vk_pipeline pipeline_dequant_mul_mat_vec_id_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT];
|
||||
|
||||
vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio];
|
||||
vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
|
||||
|
|
@ -649,6 +650,7 @@ struct vk_device_struct {
|
|||
vk_pipeline pipeline_sin_f32;
|
||||
vk_pipeline pipeline_cos_f32;
|
||||
vk_pipeline pipeline_log[2];
|
||||
vk_pipeline pipeline_tri[2];
|
||||
vk_pipeline pipeline_clamp_f32;
|
||||
vk_pipeline pipeline_pad_f32;
|
||||
vk_pipeline pipeline_roll_f32;
|
||||
|
|
@ -1610,7 +1612,7 @@ class vk_perf_logger {
|
|||
}
|
||||
if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
|
||||
const uint64_t m = node->src[0]->ne[1];
|
||||
const uint64_t n = node->ne[1];
|
||||
const uint64_t n = (node->op == GGML_OP_MUL_MAT) ? node->ne[1] : node->ne[2];
|
||||
const uint64_t k = node->src[1]->ne[0];
|
||||
const uint64_t batch = node->src[1]->ne[2] * node->src[1]->ne[3];
|
||||
std::string name = ggml_op_name(node->op);
|
||||
|
|
@ -3524,13 +3526,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
// the number of rows computed per shader depends on GPU model and quant
|
||||
uint32_t rm_stdq = 1;
|
||||
uint32_t rm_kq = 2;
|
||||
uint32_t rm_stdq_int = 1;
|
||||
uint32_t rm_kq_int = 1;
|
||||
if (device->vendor_id == VK_VENDOR_ID_AMD) {
|
||||
if (device->architecture == AMD_GCN) {
|
||||
rm_stdq = 2;
|
||||
rm_kq = 4;
|
||||
rm_stdq_int = 4;
|
||||
}
|
||||
} else if (device->vendor_id == VK_VENDOR_ID_INTEL)
|
||||
} else if (device->vendor_id == VK_VENDOR_ID_INTEL) {
|
||||
rm_stdq = 2;
|
||||
rm_stdq_int = 2;
|
||||
}
|
||||
uint32_t rm_iq = 2 * rm_kq;
|
||||
|
||||
const bool use_subgroups = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN;
|
||||
|
|
@ -3611,39 +3618,73 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
|
||||
const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32", arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32", arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32", arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32", arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_q8_1_f32", arr_dmmv_mxfp4_q8_1_f32_len[reduc], arr_dmmv_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_q8_1_f32", arr_dmmv_q2_k_q8_1_f32_len[reduc], arr_dmmv_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_q8_1_f32", arr_dmmv_q3_k_q8_1_f32_len[reduc], arr_dmmv_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_q8_1_f32", arr_dmmv_q4_k_q8_1_f32_len[reduc], arr_dmmv_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_q8_1_f32", arr_dmmv_q5_k_q8_1_f32_len[reduc], arr_dmmv_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_q8_1_f32", arr_dmmv_q6_k_q8_1_f32_len[reduc], arr_dmmv_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
|
||||
}
|
||||
#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
|
||||
}
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", arr_dmmv_id_f32_f32_f32_len[reduc], arr_dmmv_id_f32_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", arr_dmmv_id_f16_f32_f32_len[reduc], arr_dmmv_id_f16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", arr_dmmv_id_bf16_f32_f32_len[reduc], arr_dmmv_id_bf16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", arr_dmmv_id_q4_0_f32_f32_len[reduc], arr_dmmv_id_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", arr_dmmv_id_q4_1_f32_f32_len[reduc], arr_dmmv_id_q4_1_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", arr_dmmv_id_q5_0_f32_f32_len[reduc], arr_dmmv_id_q5_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", arr_dmmv_id_q5_1_f32_f32_len[reduc], arr_dmmv_id_q5_1_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", arr_dmmv_id_q8_0_f32_f32_len[reduc], arr_dmmv_id_q8_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", arr_dmmv_id_q2_k_f32_f32_len[reduc16], arr_dmmv_id_q2_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", arr_dmmv_id_q3_k_f32_f32_len[reduc16], arr_dmmv_id_q3_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", arr_dmmv_id_q4_k_f32_f32_len[reduc16], arr_dmmv_id_q4_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", arr_dmmv_id_q5_k_f32_f32_len[reduc16], arr_dmmv_id_q5_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", arr_dmmv_id_q6_k_f32_f32_len[reduc16], arr_dmmv_id_q6_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", arr_dmmv_id_iq1_s_f32_f32_len[reduc16], arr_dmmv_id_iq1_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", arr_dmmv_id_iq1_m_f32_f32_len[reduc16], arr_dmmv_id_iq1_m_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", arr_dmmv_id_iq2_xxs_f32_f32_len[reduc16], arr_dmmv_id_iq2_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", arr_dmmv_id_iq2_xs_f32_f32_len[reduc16], arr_dmmv_id_iq2_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", arr_dmmv_id_iq2_s_f32_f32_len[reduc16], arr_dmmv_id_iq2_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", arr_dmmv_id_iq3_xxs_f32_f32_len[reduc16], arr_dmmv_id_iq3_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", arr_dmmv_id_iq3_s_f32_f32_len[reduc16], arr_dmmv_id_iq3_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", arr_dmmv_id_iq4_xs_f32_f32_len[reduc16], arr_dmmv_id_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", arr_dmmv_id_iq4_nl_f32_f32_len[reduc16], arr_dmmv_id_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", arr_dmmv_id_mxfp4_f32_f32_len[reduc16], arr_dmmv_id_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product) {
|
||||
const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
|
||||
const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_q8_1_f32", arr_dmmv_id_q4_0_q8_1_f32_len[reduc], arr_dmmv_id_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_q8_1_f32", arr_dmmv_id_q4_1_q8_1_f32_len[reduc], arr_dmmv_id_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_q8_1_f32", arr_dmmv_id_q5_0_q8_1_f32_len[reduc], arr_dmmv_id_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_q8_1_f32", arr_dmmv_id_q5_1_q8_1_f32_len[reduc], arr_dmmv_id_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_q8_1_f32", arr_dmmv_id_q8_0_q8_1_f32_len[reduc], arr_dmmv_id_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_q8_1_f32", arr_dmmv_id_mxfp4_q8_1_f32_len[reduc], arr_dmmv_id_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_q8_1_f32", arr_dmmv_id_q2_k_q8_1_f32_len[reduc], arr_dmmv_id_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_q8_1_f32", arr_dmmv_id_q3_k_q8_1_f32_len[reduc], arr_dmmv_id_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
|
||||
}
|
||||
#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
|
||||
}
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", mul_mat_vec_id_bf16_f32_len, mul_mat_vec_id_bf16_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", mul_mat_vec_id_mxfp4_f32_len, mul_mat_vec_id_mxfp4_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
|
||||
#if !defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
GGML_UNUSED(rm_stdq_int);
|
||||
GGML_UNUSED(rm_kq_int);
|
||||
#endif
|
||||
|
||||
// dequant shaders
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
|
||||
|
|
@ -3876,6 +3917,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
}
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
|
@ -5449,6 +5493,12 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
|
|||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
break;
|
||||
default:
|
||||
return nullptr;
|
||||
|
|
@ -5588,9 +5638,28 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
|
|||
}
|
||||
}
|
||||
|
||||
static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
|
||||
static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t m, uint32_t k) {
|
||||
VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec_id()");
|
||||
GGML_ASSERT(b_type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_Q8_1);
|
||||
|
||||
if (b_type == GGML_TYPE_Q8_1) {
|
||||
switch (a_type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
break;
|
||||
default:
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
switch (a_type) {
|
||||
case GGML_TYPE_F32:
|
||||
|
|
@ -5621,7 +5690,31 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type];
|
||||
// heuristic to choose workgroup size
|
||||
uint32_t dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
|
||||
if ((ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA && ctx->device->architecture != vk_device_architecture::NVIDIA_PRE_TURING) || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
|
||||
// Prefer larger workgroups when M is small, to spread the work out more
|
||||
// and keep more SMs busy.
|
||||
// q6_k seems to prefer small workgroup size even for "medium" values of M.
|
||||
if (a_type == GGML_TYPE_Q6_K) {
|
||||
if (m < 4096 && k >= 1024) {
|
||||
dmmv_wg = DMMV_WG_SIZE_LARGE;
|
||||
}
|
||||
} else {
|
||||
if (m <= 8192 && k >= 1024) {
|
||||
dmmv_wg = DMMV_WG_SIZE_LARGE;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (b_type == GGML_TYPE_Q8_1) {
|
||||
if (ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
|
||||
dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
|
||||
}
|
||||
return ctx->device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[dmmv_wg][a_type];
|
||||
}
|
||||
|
||||
return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[dmmv_wg][a_type];
|
||||
}
|
||||
|
||||
static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
|
||||
|
|
@ -6813,20 +6906,35 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
|
|||
return false;
|
||||
}
|
||||
|
||||
// General performance issue with q3_k and q6_k due to 2-byte alignment
|
||||
if (src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q6_K) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// MMVQ is generally good for batches
|
||||
if (n > 1) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Quantization overhead is not worth it for small k
|
||||
switch (device->vendor_id) {
|
||||
case VK_VENDOR_ID_NVIDIA:
|
||||
if (k <= 4096) {
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (src0_type) {
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_Q8_0:
|
||||
return device->architecture == vk_device_architecture::NVIDIA_PRE_TURING;
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
case VK_VENDOR_ID_AMD:
|
||||
if (k < 2048) {
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (src0_type) {
|
||||
case GGML_TYPE_Q8_0:
|
||||
return device->architecture == vk_device_architecture::AMD_GCN;
|
||||
|
|
@ -6834,6 +6942,10 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
|
|||
return true;
|
||||
}
|
||||
case VK_VENDOR_ID_INTEL:
|
||||
if (k < 2048) {
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (src0_type) {
|
||||
// From tests on A770 Linux, may need more tuning
|
||||
case GGML_TYPE_Q4_0:
|
||||
|
|
@ -6847,7 +6959,6 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
|
|||
}
|
||||
|
||||
GGML_UNUSED(m);
|
||||
GGML_UNUSED(k);
|
||||
}
|
||||
|
||||
static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
|
||||
|
|
@ -7570,7 +7681,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|||
if (x_non_contig || qx_needs_dequant) {
|
||||
ctx->prealloc_x_need_sync = true;
|
||||
}
|
||||
if (y_non_contig) {
|
||||
if (y_non_contig || quantize_y) {
|
||||
ctx->prealloc_y_need_sync = true;
|
||||
}
|
||||
}
|
||||
|
|
@ -7596,7 +7707,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
|||
|
||||
const uint64_t ne10 = src1->ne[0];
|
||||
const uint64_t ne11 = src1->ne[1];
|
||||
// const uint64_t ne12 = src1->ne[2];
|
||||
const uint64_t ne12 = src1->ne[2];
|
||||
// const uint64_t ne13 = src1->ne[3];
|
||||
|
||||
const uint64_t nei0 = ids->ne[0];
|
||||
|
|
@ -7613,19 +7724,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
|||
const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
|
||||
|
||||
const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
|
||||
|
||||
const bool qx_needs_dequant = x_non_contig;
|
||||
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig;
|
||||
|
||||
// Not implemented
|
||||
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
|
||||
|
||||
const uint64_t x_ne = ggml_nelements(src0);
|
||||
const uint64_t y_ne = ggml_nelements(src1);
|
||||
|
||||
const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
|
||||
const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
|
||||
const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
|
||||
bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne12, ne10, src0->type);
|
||||
|
||||
vk_pipeline to_fp16_vk_0 = nullptr;
|
||||
vk_pipeline to_fp16_vk_1 = nullptr;
|
||||
|
|
@ -7637,11 +7736,38 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
|||
} else {
|
||||
to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
|
||||
}
|
||||
vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type);
|
||||
|
||||
// Check for mmq first
|
||||
vk_pipeline dmmv = quantize_y ? ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, GGML_TYPE_Q8_1, ne20, ne00) : nullptr;
|
||||
vk_pipeline to_q8_1 = nullptr;
|
||||
|
||||
if (dmmv == nullptr) {
|
||||
// Fall back to f16 dequant mul mat
|
||||
dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type, ne20, ne00);
|
||||
quantize_y = false;
|
||||
}
|
||||
|
||||
if (quantize_y) {
|
||||
to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
|
||||
}
|
||||
|
||||
const bool qx_needs_dequant = x_non_contig;
|
||||
const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig);
|
||||
|
||||
// Not implemented
|
||||
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
|
||||
GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
|
||||
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
|
||||
GGML_ASSERT(dmmv != nullptr);
|
||||
|
||||
const uint64_t x_ne = ggml_nelements(src0);
|
||||
const uint64_t y_ne = ggml_nelements(src1);
|
||||
|
||||
const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
|
||||
const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
|
||||
const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) :
|
||||
(f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
|
||||
|
||||
{
|
||||
if (
|
||||
(qx_needs_dequant && x_sz > ctx->device->properties.limits.maxStorageBufferRange) ||
|
||||
|
|
@ -7652,7 +7778,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
|||
ctx->prealloc_size_x = x_sz;
|
||||
ggml_vk_preallocate_buffers(ctx, subctx);
|
||||
}
|
||||
if (qy_needs_dequant && ctx->prealloc_size_y < y_sz) {
|
||||
if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz) {
|
||||
ctx->prealloc_size_y = y_sz;
|
||||
ggml_vk_preallocate_buffers(ctx, subctx);
|
||||
}
|
||||
|
|
@ -7664,6 +7790,9 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
|||
if (qy_needs_dequant) {
|
||||
ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
|
||||
}
|
||||
if (quantize_y) {
|
||||
ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
|
||||
}
|
||||
ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
|
||||
}
|
||||
|
||||
|
|
@ -7679,7 +7808,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
|||
} else {
|
||||
d_X = d_Qx;
|
||||
}
|
||||
if (qy_needs_dequant) {
|
||||
if (qy_needs_dequant || quantize_y) {
|
||||
d_Y = { ctx->prealloc_y, 0, ctx->prealloc_y->size };
|
||||
} else {
|
||||
d_Y = d_Qy;
|
||||
|
|
@ -7707,6 +7836,17 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
|||
ctx->prealloc_y_last_tensor_used = src1;
|
||||
}
|
||||
}
|
||||
if (quantize_y) {
|
||||
if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
|
||||
ctx->prealloc_y_last_tensor_used != src1) {
|
||||
if (ctx->prealloc_y_need_sync) {
|
||||
ggml_vk_sync_buffers(ctx, subctx);
|
||||
}
|
||||
ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne);
|
||||
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
|
||||
ctx->prealloc_y_last_tensor_used = src1;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t stride_batch_y = ne10*ne11;
|
||||
|
||||
|
|
@ -7768,7 +7908,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
|||
if (x_non_contig) {
|
||||
ctx->prealloc_x_need_sync = true;
|
||||
}
|
||||
if (y_non_contig) {
|
||||
if (y_non_contig || quantize_y) {
|
||||
ctx->prealloc_y_need_sync = true;
|
||||
}
|
||||
}
|
||||
|
|
@ -8290,6 +8430,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|||
return ctx->device->pipeline_log[dst->type == GGML_TYPE_F16];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_TRI:
|
||||
if (src0->type == dst->type &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
|
||||
return ctx->device->pipeline_tri[dst->type == GGML_TYPE_F16];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_CLAMP:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_clamp_f32;
|
||||
|
|
@ -8991,6 +9137,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_LOG:
|
||||
case GGML_OP_TRI:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_ROLL:
|
||||
|
|
@ -9671,6 +9818,13 @@ static void ggml_vk_log(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
|||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LOG, vk_op_unary_push_constants_init(src0, dst));
|
||||
}
|
||||
|
||||
static void ggml_vk_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
|
||||
p.param1 = ggml_get_op_params_f32(dst, 0);
|
||||
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TRI, std::move(p));
|
||||
}
|
||||
|
||||
static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
|
||||
p.param1 = ggml_get_op_params_f32(dst, 0);
|
||||
|
|
@ -10221,7 +10375,9 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
|||
|
||||
// Prefer going as small as num_topk_pipelines - 3 for perf reasons.
|
||||
// But if K is larger, then we need a larger workgroup
|
||||
uint32_t max_pipeline = num_topk_pipelines - 3;
|
||||
uint32_t max_pipeline = num_topk_pipelines - 1;
|
||||
uint32_t preferred_pipeline = std::max(num_topk_pipelines - 3, (uint32_t)log2f(float(k)) + 2);
|
||||
max_pipeline = std::min(preferred_pipeline, max_pipeline);
|
||||
uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1;
|
||||
// require full subgroup
|
||||
min_pipeline = std::max(min_pipeline, ctx->device->subgroup_size_log2);
|
||||
|
|
@ -11794,6 +11950,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|||
case GGML_OP_LOG:
|
||||
ggml_vk_log(ctx, compute_ctx, src0, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_TRI:
|
||||
ggml_vk_tri(ctx, compute_ctx, src0, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_CLAMP:
|
||||
ggml_vk_clamp(ctx, compute_ctx, src0, node);
|
||||
|
|
@ -13919,7 +14079,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||
case GGML_OP_OPT_STEP_SGD:
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_LOG:
|
||||
return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
|
||||
case GGML_OP_TRI:
|
||||
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
||||
op->type == op->src[0]->type;
|
||||
case GGML_OP_ARGSORT:
|
||||
{
|
||||
if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
|
||||
|
|
@ -14510,6 +14672,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
|||
tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
|
||||
} else if (tensor->op == GGML_OP_LOG) {
|
||||
tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
|
||||
} else if (tensor->op == GGML_OP_TRI) {
|
||||
tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0));
|
||||
} else if (tensor->op == GGML_OP_CLAMP) {
|
||||
const float * params = (const float *)tensor->op_params;
|
||||
tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
|
||||
|
|
|
|||
|
|
@ -4,13 +4,6 @@
|
|||
|
||||
#include "types.glsl"
|
||||
|
||||
#if defined(A_TYPE_PACKED16)
|
||||
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
|
||||
#endif
|
||||
#if defined(A_TYPE_PACKED32)
|
||||
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_F32)
|
||||
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
|
||||
|
|
|
|||
|
|
@ -156,7 +156,7 @@ void main() {
|
|||
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
||||
tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv, mvmax;
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
|
||||
|
||||
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
|
||||
|
||||
|
|
|
|||
|
|
@ -22,6 +22,13 @@ layout (push_constant) uniform parameter
|
|||
|
||||
#if !RMS_NORM_ROPE_FUSION
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
#if defined(A_TYPE_PACKED16)
|
||||
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
|
||||
#endif
|
||||
#if defined(A_TYPE_PACKED32)
|
||||
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
|
||||
#endif
|
||||
|
||||
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -18,6 +18,13 @@ layout (push_constant) uniform parameter
|
|||
} p;
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
#if defined(A_TYPE_PACKED16)
|
||||
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
|
||||
#endif
|
||||
#if defined(A_TYPE_PACKED32)
|
||||
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
|
||||
#endif
|
||||
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
uint get_idx() {
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||
|
||||
#include "mul_mat_vec_base.glsl"
|
||||
#include "dequant_funcs.glsl"
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
|
|
|
|||
|
|
@ -13,8 +13,6 @@
|
|||
|
||||
#include "mul_mat_vec_iface.glsl"
|
||||
|
||||
#include "dequant_funcs.glsl"
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
uint ncols;
|
||||
|
|
|
|||
|
|
@ -5,13 +5,15 @@
|
|||
#define MAT_VEC_FUSION_FLAGS_SCALE0 0x4
|
||||
#define MAT_VEC_FUSION_FLAGS_SCALE1 0x8
|
||||
|
||||
#ifndef MMQ
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
#if defined(A_TYPE_VEC4)
|
||||
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
|
||||
#endif
|
||||
#else
|
||||
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
|
||||
#if defined(A_TYPE_PACKED16)
|
||||
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
|
||||
#endif
|
||||
#if defined(A_TYPE_PACKED32)
|
||||
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
|
||||
#endif
|
||||
|
||||
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
||||
|
|
|
|||
|
|
@ -10,60 +10,56 @@
|
|||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4)
|
||||
#define K_PER_ITER 8
|
||||
|
||||
#include "mul_mmq_funcs.glsl"
|
||||
#elif defined(DATA_A_QUANT_K)
|
||||
#define K_PER_ITER 16
|
||||
#else
|
||||
#error unimplemented
|
||||
#endif
|
||||
|
||||
uint a_offset, b_offset, d_offset;
|
||||
|
||||
int32_t cache_b_qs[2];
|
||||
int32_t cache_b_qs[K_PER_ITER / 4];
|
||||
vec2 cache_b_ds;
|
||||
|
||||
#include "mul_mat_vecq_funcs.glsl"
|
||||
|
||||
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) {
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
const uint col = i*BLOCK_SIZE + tid*K_PER_ITER;
|
||||
|
||||
// Preload data_b block
|
||||
const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;
|
||||
const uint b_qs_idx = tid % 4;
|
||||
const uint b_qs_idx = tid % (32 / K_PER_ITER);
|
||||
const uint b_block_idx_outer = b_block_idx / 4;
|
||||
const uint b_block_idx_inner = b_block_idx % 4;
|
||||
cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);
|
||||
|
||||
#if QUANT_R == 2
|
||||
// Assumes K_PER_ITER == 8
|
||||
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx];
|
||||
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4];
|
||||
#else
|
||||
#if K_PER_ITER == 8
|
||||
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2];
|
||||
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1];
|
||||
#elif K_PER_ITER == 16
|
||||
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 ];
|
||||
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 1];
|
||||
cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 2];
|
||||
cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 3];
|
||||
#else
|
||||
#error unimplemented
|
||||
#endif
|
||||
#endif
|
||||
|
||||
uint ibi = first_row*p.ncols;
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
const uint a_block_idx = (ibi + col)/QUANT_K + a_offset;
|
||||
const uint a_block_idx = (ibi + col)/QUANT_K_Q8_1 + a_offset;
|
||||
ibi += p.ncols;
|
||||
|
||||
int32_t q_sum = 0;
|
||||
#if QUANT_R == 2
|
||||
const i32vec2 data_a_qs = repack(a_block_idx, b_qs_idx);
|
||||
q_sum += dotPacked4x8EXT(data_a_qs.x,
|
||||
cache_b_qs[0]);
|
||||
q_sum += dotPacked4x8EXT(data_a_qs.y,
|
||||
cache_b_qs[1]);
|
||||
#else
|
||||
int32_t data_a_qs = repack(a_block_idx, b_qs_idx * 2);
|
||||
q_sum += dotPacked4x8EXT(data_a_qs,
|
||||
cache_b_qs[0]);
|
||||
data_a_qs = repack(a_block_idx, b_qs_idx * 2 + 1);
|
||||
q_sum += dotPacked4x8EXT(data_a_qs,
|
||||
cache_b_qs[1]);
|
||||
#endif
|
||||
|
||||
#if QUANT_AUXF == 1
|
||||
temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds, 4);
|
||||
#else
|
||||
temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds, 4);
|
||||
#endif
|
||||
temp[j][n] += mmvq_dot_product(a_block_idx, b_qs_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -72,7 +68,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||
const uint tid = gl_LocalInvocationID.x;
|
||||
|
||||
get_offsets(a_offset, b_offset, d_offset);
|
||||
a_offset /= QUANT_K;
|
||||
a_offset /= QUANT_K_Q8_1;
|
||||
b_offset /= QUANT_K_Q8_1;
|
||||
|
||||
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
|
||||
|
|
@ -102,14 +98,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||
unroll_count = 2;
|
||||
unrolled_iters = num_iters & ~(unroll_count - 1);
|
||||
|
||||
#if K_PER_ITER == 2
|
||||
if ((p.ncols & 1) != 0 &&
|
||||
unrolled_iters == num_iters &&
|
||||
unrolled_iters > 0) {
|
||||
unrolled_iters -= unroll_count;
|
||||
}
|
||||
#endif
|
||||
|
||||
while (i < unrolled_iters) {
|
||||
// Manually partially unroll the loop
|
||||
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
|
||||
|
|
@ -128,6 +116,10 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||
void main() {
|
||||
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
|
||||
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
// do NUM_ROWS at a time, unless there aren't enough remaining rows
|
||||
if (first_row + NUM_ROWS <= p.stride_d) {
|
||||
compute_outputs(first_row, NUM_ROWS);
|
||||
|
|
|
|||
|
|
@ -0,0 +1,379 @@
|
|||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
|
||||
|
||||
#include "types.glsl"
|
||||
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
|
||||
FLOAT_TYPE get_dm(uint ib) {
|
||||
return FLOAT_TYPE(data_a[ib].d);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
|
||||
FLOAT_TYPE_VEC2 get_dm(uint ib) {
|
||||
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_MXFP4)
|
||||
FLOAT_TYPE get_dm(uint ib) {
|
||||
return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q2_K)
|
||||
FLOAT_TYPE_VEC2 get_dm(uint ib) {
|
||||
const uint ib_k = ib / 8;
|
||||
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Each iqs value maps to a 32-bit integer
|
||||
#if defined(DATA_A_Q4_0)
|
||||
// 2-byte loads for Q4_0 blocks (18 bytes)
|
||||
i32vec2 repack(uint ib, uint iqs) {
|
||||
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
|
||||
data_a_packed16[ib].qs[iqs * 2 + 1]);
|
||||
const uint32_t vui = pack32(quants);
|
||||
return i32vec2( vui & 0x0F0F0F0F,
|
||||
(vui >> 4) & 0x0F0F0F0F);
|
||||
}
|
||||
|
||||
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return FLOAT_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_1)
|
||||
// 4-byte loads for Q4_1 blocks (20 bytes)
|
||||
i32vec2 repack(uint ib, uint iqs) {
|
||||
const uint32_t vui = data_a_packed32[ib].qs[iqs];
|
||||
return i32vec2( vui & 0x0F0F0F0F,
|
||||
(vui >> 4) & 0x0F0F0F0F);
|
||||
}
|
||||
|
||||
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return FLOAT_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q5_0)
|
||||
// 2-byte loads for Q5_0 blocks (22 bytes)
|
||||
i32vec2 repack(uint ib, uint iqs) {
|
||||
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
|
||||
data_a_packed16[ib].qs[iqs * 2 + 1]);
|
||||
const uint32_t vui = pack32(quants);
|
||||
const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs));
|
||||
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
|
||||
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
|
||||
|
||||
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
|
||||
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
|
||||
|
||||
return i32vec2(v0, v1);
|
||||
}
|
||||
|
||||
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return FLOAT_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q5_1)
|
||||
// 4-byte loads for Q5_1 blocks (24 bytes)
|
||||
i32vec2 repack(uint ib, uint iqs) {
|
||||
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
|
||||
data_a_packed16[ib].qs[iqs * 2 + 1]);
|
||||
const uint32_t vui = pack32(quants);
|
||||
const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
|
||||
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
|
||||
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
|
||||
|
||||
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
|
||||
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
|
||||
|
||||
return i32vec2(v0, v1);
|
||||
}
|
||||
|
||||
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return FLOAT_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q8_0)
|
||||
// 2-byte loads for Q8_0 blocks (34 bytes)
|
||||
int32_t repack(uint ib, uint iqs) {
|
||||
return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ],
|
||||
data_a_packed16[ib].qs[iqs * 2 + 1]));
|
||||
}
|
||||
|
||||
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return FLOAT_TYPE(float(q_sum) * da * dsb.x);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_MXFP4)
|
||||
// 1-byte loads for mxfp4 blocks (17 bytes)
|
||||
i32vec2 repack(uint ib, uint iqs) {
|
||||
const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
|
||||
data_a[ib].qs[iqs * 4 + 1],
|
||||
data_a[ib].qs[iqs * 4 + 2],
|
||||
data_a[ib].qs[iqs * 4 + 3]));
|
||||
|
||||
const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
|
||||
const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
|
||||
|
||||
return i32vec2(pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w])),
|
||||
pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w])));
|
||||
}
|
||||
|
||||
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return FLOAT_TYPE(da * dsb.x * float(q_sum) * 0.5);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4)
|
||||
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
|
||||
int32_t q_sum = 0;
|
||||
#if QUANT_R == 2
|
||||
const i32vec2 data_a_qs = repack(ib_a, iqs);
|
||||
q_sum += dotPacked4x8EXT(data_a_qs.x,
|
||||
cache_b_qs[0]);
|
||||
q_sum += dotPacked4x8EXT(data_a_qs.y,
|
||||
cache_b_qs[1]);
|
||||
#else
|
||||
int32_t data_a_qs = repack(ib_a, iqs * 2);
|
||||
q_sum += dotPacked4x8EXT(data_a_qs,
|
||||
cache_b_qs[0]);
|
||||
data_a_qs = repack(ib_a, iqs * 2 + 1);
|
||||
q_sum += dotPacked4x8EXT(data_a_qs,
|
||||
cache_b_qs[1]);
|
||||
#endif
|
||||
|
||||
// 2 quants per call => divide sums by 8/2 = 4
|
||||
return mul_q8_1(q_sum, get_dm(ib_a), cache_b_ds, 4);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q2_K)
|
||||
// 4-byte loads for Q2_K blocks (84 bytes)
|
||||
i32vec4 repack4(uint ib, uint iqs) {
|
||||
const uint ib_k = ib / 8;
|
||||
const uint iqs_k = (ib % 8) * 8 + iqs;
|
||||
|
||||
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
|
||||
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
|
||||
|
||||
return i32vec4((data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303,
|
||||
(data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303,
|
||||
(data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303,
|
||||
(data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303);
|
||||
}
|
||||
|
||||
uint8_t get_scale(uint ib, uint iqs) {
|
||||
const uint ib_k = ib / 8;
|
||||
const uint iqs_k = (ib % 8) * 8 + iqs;
|
||||
|
||||
return data_a[ib_k].scales[iqs_k / 4];
|
||||
}
|
||||
|
||||
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
|
||||
int32_t sum_d = 0;
|
||||
int32_t sum_m = 0;
|
||||
|
||||
const i32vec4 qs_a = repack4(ib_a, iqs * 4);
|
||||
const uint8_t scale = get_scale(ib_a, iqs * 4);
|
||||
const vec2 dm = vec2(get_dm(ib_a));
|
||||
const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
|
||||
|
||||
sum_d += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]) * (scale & 0xF);
|
||||
sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[0]);
|
||||
|
||||
sum_d += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]) * (scale & 0xF);
|
||||
sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[1]);
|
||||
|
||||
sum_d += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]) * (scale & 0xF);
|
||||
sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[2]);
|
||||
|
||||
sum_d += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]) * (scale & 0xF);
|
||||
sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[3]);
|
||||
|
||||
return FLOAT_TYPE(float(cache_b_ds.x) * (float(dm.x) * float(sum_d) - float(dm.y) * float(sum_m)));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q3_K)
|
||||
// 2-byte loads for Q3_K blocks (110 bytes)
|
||||
i32vec4 repack4(uint ib, uint iqs) {
|
||||
const uint ib_k = ib / 8;
|
||||
const uint iqs_k = (ib % 8) * 8 + iqs;
|
||||
|
||||
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
|
||||
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
|
||||
const uint hm_shift = iqs_k / 8;
|
||||
|
||||
// bitwise OR to add 4 if hmask is set, subtract later
|
||||
const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2));
|
||||
const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1] >> qs_shift) & uint16_t(0x0303))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2));
|
||||
const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2] >> qs_shift) & uint16_t(0x0303))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2));
|
||||
const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3] >> qs_shift) & uint16_t(0x0303))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2));
|
||||
const i8vec2 vals20 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 4] >> qs_shift) & uint16_t(0x0303))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 4] >> hm_shift) & uint16_t(0x0101)) << 2));
|
||||
const i8vec2 vals21 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 5] >> qs_shift) & uint16_t(0x0303))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 5] >> hm_shift) & uint16_t(0x0101)) << 2));
|
||||
const i8vec2 vals30 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 6] >> qs_shift) & uint16_t(0x0303))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 6] >> hm_shift) & uint16_t(0x0101)) << 2));
|
||||
const i8vec2 vals31 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 7] >> qs_shift) & uint16_t(0x0303))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 7] >> hm_shift) & uint16_t(0x0101)) << 2));
|
||||
|
||||
return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y) - int8_t(4)),
|
||||
pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y) - int8_t(4)),
|
||||
pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y) - int8_t(4)),
|
||||
pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y) - int8_t(4)));
|
||||
}
|
||||
|
||||
float get_d_scale(uint ib, uint iqs) {
|
||||
const uint ib_k = ib / 8;
|
||||
const uint iqs_k = (ib % 8) * 8 + iqs;
|
||||
const uint is = iqs_k / 4;
|
||||
|
||||
const int8_t scale = int8_t(((data_a[ib_k].scales[is % 8 ] >> (4 * (is / 8))) & 0x0F0F) |
|
||||
(((data_a[ib_k].scales[8 + (is % 4)] >> (2 * (is / 4))) & 0x0303) << 4));
|
||||
return float(data_a[ib_k].d) * float(scale - 32);
|
||||
}
|
||||
|
||||
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
|
||||
int32_t q_sum = 0;
|
||||
|
||||
const i32vec4 qs_a = repack4(ib_a, iqs * 4);
|
||||
const float d_scale = get_d_scale(ib_a, iqs * 4);
|
||||
|
||||
q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]);
|
||||
q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]);
|
||||
q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]);
|
||||
q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]);
|
||||
|
||||
return FLOAT_TYPE(float(cache_b_ds.x) * d_scale * float(q_sum));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
|
||||
// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
|
||||
i32vec4 repack4(uint ib, uint iqs) {
|
||||
const uint ib_k = ib / 8;
|
||||
const uint iqs_k = (ib % 8) * 8 + iqs;
|
||||
|
||||
const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8);
|
||||
const uint qs_shift = ((iqs_k % 16) / 8) * 4;
|
||||
|
||||
#if defined(DATA_A_Q4_K)
|
||||
const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F;
|
||||
const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F;
|
||||
const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x0F0F0F0F;
|
||||
const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x0F0F0F0F;
|
||||
|
||||
return i32vec4(vals0, vals1, vals2, vals3);
|
||||
#else // defined(DATA_A_Q5_K)
|
||||
const uint qh_idx = iqs;
|
||||
const uint qh_shift = iqs_k / 8;
|
||||
|
||||
return i32vec4(((data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F) |
|
||||
(((data_a_packed32[ib_k].qh[qh_idx ] >> qh_shift) & 0x01010101) << 4),
|
||||
((data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F) |
|
||||
(((data_a_packed32[ib_k].qh[qh_idx + 1] >> qh_shift) & 0x01010101) << 4),
|
||||
((data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x0F0F0F0F) |
|
||||
(((data_a_packed32[ib_k].qh[qh_idx + 2] >> qh_shift) & 0x01010101) << 4),
|
||||
((data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x0F0F0F0F) |
|
||||
(((data_a_packed32[ib_k].qh[qh_idx + 3] >> qh_shift) & 0x01010101) << 4));
|
||||
#endif
|
||||
}
|
||||
|
||||
vec2 get_dm_scale(uint ib, uint iqs) {
|
||||
const uint ib_k = ib / 8;
|
||||
const uint iqs_k = (ib % 8) * 8 + iqs;
|
||||
const uint is = iqs_k / 8;
|
||||
u8vec2 scale_dm;
|
||||
if (is < 4) {
|
||||
scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F);
|
||||
} else {
|
||||
scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2),
|
||||
(data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
|
||||
}
|
||||
|
||||
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
|
||||
}
|
||||
|
||||
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
|
||||
int32_t q_sum = 0;
|
||||
|
||||
const i32vec4 qs_a = repack4(ib_a, iqs * 4);
|
||||
const vec2 dm_scale = get_dm_scale(ib_a, iqs * 4);
|
||||
|
||||
q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]);
|
||||
q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]);
|
||||
q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]);
|
||||
q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]);
|
||||
|
||||
return FLOAT_TYPE(float(cache_b_ds.x) * float(dm_scale.x) * float(q_sum) - float(dm_scale.y) * float(cache_b_ds.y / 2));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q6_K)
|
||||
// 2-byte loads for Q6_K blocks (210 bytes)
|
||||
i32vec4 repack4(uint ib, uint iqs) {
|
||||
const uint ib_k = ib / 8;
|
||||
const uint iqs_k = (ib % 8) * 8 + iqs;
|
||||
|
||||
const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16;
|
||||
const uint ql_shift = ((iqs_k % 32) / 16) * 4;
|
||||
|
||||
const uint qh_idx = (iqs_k / 32) * 8 + iqs;
|
||||
const uint qh_shift = ((iqs_k % 32) / 8) * 2;
|
||||
|
||||
const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
|
||||
const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
|
||||
const i8vec2 vals10 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 2] >> ql_shift) & uint16_t(0x0F0F))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 2] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
|
||||
const i8vec2 vals11 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 3] >> ql_shift) & uint16_t(0x0F0F))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 3] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
|
||||
const i8vec2 vals20 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 4] >> ql_shift) & uint16_t(0x0F0F))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 4] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
|
||||
const i8vec2 vals21 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 5] >> ql_shift) & uint16_t(0x0F0F))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 5] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
|
||||
const i8vec2 vals30 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 6] >> ql_shift) & uint16_t(0x0F0F))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 6] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
|
||||
const i8vec2 vals31 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 7] >> ql_shift) & uint16_t(0x0F0F))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 7] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
|
||||
|
||||
return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y)),
|
||||
pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y)),
|
||||
pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y)),
|
||||
pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y)));
|
||||
}
|
||||
|
||||
float get_d_scale(uint ib, uint iqs) {
|
||||
const uint ib_k = ib / 8;
|
||||
const uint iqs_k = (ib % 8) * 8 + iqs;
|
||||
return float(data_a[ib_k].d) * float(data_a[ib_k].scales[iqs_k / 4]);
|
||||
}
|
||||
|
||||
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
|
||||
int32_t q_sum = 0;
|
||||
|
||||
const i32vec4 qs_a = repack4(ib_a, iqs * 4);
|
||||
const float d_scale = get_d_scale(ib_a, iqs * 4);
|
||||
|
||||
q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]);
|
||||
q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]);
|
||||
q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]);
|
||||
q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]);
|
||||
|
||||
return FLOAT_TYPE(float(cache_b_ds.x) * float(d_scale) * float(q_sum));
|
||||
}
|
||||
#endif
|
||||
|
|
@ -78,8 +78,6 @@ layout (constant_id = 10) const uint WARP = 32;
|
|||
|
||||
#define BK 32
|
||||
|
||||
#define MMQ_SHMEM
|
||||
|
||||
#include "mul_mmq_shmem_types.glsl"
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
|
|
|
|||
|
|
@ -9,31 +9,6 @@
|
|||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
|
||||
// 2-byte loads for Q4_0 blocks (18 bytes)
|
||||
// 4-byte loads for Q4_1 blocks (20 bytes)
|
||||
i32vec2 repack(uint ib, uint iqs) {
|
||||
#ifdef DATA_A_Q4_0
|
||||
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
|
||||
data_a_packed16[ib].qs[iqs * 2 + 1]);
|
||||
const uint32_t vui = pack32(quants);
|
||||
return i32vec2( vui & 0x0F0F0F0F,
|
||||
(vui >> 4) & 0x0F0F0F0F);
|
||||
#else // DATA_A_Q4_1
|
||||
const uint32_t vui = data_a_packed32[ib].qs[iqs];
|
||||
return i32vec2( vui & 0x0F0F0F0F,
|
||||
(vui >> 4) & 0x0F0F0F0F);
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef DATA_A_Q4_0
|
||||
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
|
||||
}
|
||||
#else // DATA_A_Q4_1
|
||||
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MMQ_SHMEM
|
||||
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
#ifdef DATA_A_Q4_0
|
||||
buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
|
||||
|
|
@ -73,42 +48,17 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
|
|||
q_sum += dotPacked4x8EXT(qs_a.y, qs_b1);
|
||||
}
|
||||
|
||||
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
|
||||
#ifdef DATA_A_Q4_0
|
||||
return ACC_TYPE(float(cache_a[ib_a].dm) * (float(q_sum) * float(cache_b.ds.x) - 8.0 * float(cache_b.ds.y)));
|
||||
#else // DATA_A_Q4_1
|
||||
return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm.x) * float(cache_b.ds.x) + float(cache_a[ib_a].dm.y) * float(cache_b.ds.y));
|
||||
#endif
|
||||
}
|
||||
#endif // MMQ_SHMEM
|
||||
#endif
|
||||
|
||||
#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
// 2-byte loads for Q5_0 blocks (22 bytes)
|
||||
// 4-byte loads for Q5_1 blocks (24 bytes)
|
||||
i32vec2 repack(uint ib, uint iqs) {
|
||||
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
|
||||
data_a_packed16[ib].qs[iqs * 2 + 1]);
|
||||
const uint32_t vui = pack32(quants);
|
||||
#ifdef DATA_A_Q5_0
|
||||
const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs));
|
||||
#else // DATA_A_Q5_1
|
||||
const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
|
||||
#endif
|
||||
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
|
||||
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
|
||||
|
||||
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
|
||||
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
|
||||
|
||||
return i32vec2(v0, v1);
|
||||
}
|
||||
|
||||
#ifdef DATA_A_Q5_0
|
||||
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
|
||||
}
|
||||
#else // DATA_A_Q5_1
|
||||
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MMQ_SHMEM
|
||||
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
#ifdef DATA_A_Q5_0
|
||||
buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
|
||||
|
|
@ -154,23 +104,16 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
|
|||
q_sum += dotPacked4x8EXT(qs_a1, qs_b1);
|
||||
}
|
||||
|
||||
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
|
||||
#ifdef DATA_A_Q5_0
|
||||
return ACC_TYPE(float(cache_a[ib_a].dm) * (float(q_sum) * float(cache_b.ds.x) - 16.0 * float(cache_b.ds.y)));
|
||||
#else // DATA_A_Q5_1
|
||||
return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm.x) * float(cache_b.ds.x) + float(cache_a[ib_a].dm.y) * float(cache_b.ds.y));
|
||||
#endif
|
||||
}
|
||||
#endif // MMQ_SHMEM
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q8_0)
|
||||
// 2-byte loads for Q8_0 blocks (34 bytes)
|
||||
int32_t repack(uint ib, uint iqs) {
|
||||
return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ],
|
||||
data_a_packed16[ib].qs[iqs * 2 + 1]));
|
||||
}
|
||||
|
||||
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return ACC_TYPE(float(q_sum) * da * dsb.x);
|
||||
}
|
||||
|
||||
#ifdef MMQ_SHMEM
|
||||
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2],
|
||||
data_a_packed16[ib].qs[iqs * 2 + 1]));
|
||||
|
|
@ -197,28 +140,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
|
|||
q_sum += dotPacked4x8EXT(qs_a, qs_b);
|
||||
}
|
||||
|
||||
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
|
||||
return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm) * float(cache_b.ds.x));
|
||||
}
|
||||
#endif // MMQ_SHMEM
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_MXFP4)
|
||||
// 1-byte loads for mxfp4 blocks (17 bytes)
|
||||
i32vec2 repack(uint ib, uint iqs) {
|
||||
const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
|
||||
data_a[ib].qs[iqs * 4 + 1],
|
||||
data_a[ib].qs[iqs * 4 + 2],
|
||||
data_a[ib].qs[iqs * 4 + 3]));
|
||||
|
||||
return i32vec2( quants & 0x0F0F0F0F,
|
||||
(quants >> 4) & 0x0F0F0F0F);
|
||||
}
|
||||
|
||||
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return ACC_TYPE(da * dsb.x * float(q_sum));
|
||||
}
|
||||
|
||||
#ifdef MMQ_SHMEM
|
||||
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
|
||||
data_a[ib].qs[iqs * 4 + 1],
|
||||
|
|
@ -252,37 +179,14 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
|
|||
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
|
||||
}
|
||||
|
||||
return mul_q8_1(q_sum, cache_a[ib_a].d, cache_b.ds, 1);
|
||||
return ACC_TYPE(float(cache_a[ib_a].d) * float(cache_b.ds.x) * float(q_sum));
|
||||
}
|
||||
#endif // MMQ_SHMEM
|
||||
#endif
|
||||
|
||||
// For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide
|
||||
// iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants
|
||||
#if defined(DATA_A_Q2_K)
|
||||
// 4-byte loads for Q2_K blocks (84 bytes)
|
||||
int32_t repack(uint ib, uint iqs) {
|
||||
const uint ib_k = ib / 8;
|
||||
const uint iqs_k = (ib % 8) * 8 + iqs;
|
||||
|
||||
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
|
||||
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
|
||||
|
||||
return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303);
|
||||
}
|
||||
|
||||
uint8_t get_scale(uint ib, uint iqs) {
|
||||
const uint ib_k = ib / 8;
|
||||
const uint iqs_k = (ib % 8) * 8 + iqs;
|
||||
|
||||
return data_a[ib_k].scales[iqs_k / 4];
|
||||
}
|
||||
|
||||
ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m)));
|
||||
}
|
||||
|
||||
#ifdef MMQ_SHMEM
|
||||
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
const uint ib_k = ib / 8;
|
||||
const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
|
||||
|
|
@ -326,14 +230,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
|
|||
sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
|
||||
}
|
||||
|
||||
return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1);
|
||||
return ACC_TYPE(float(cache_b.ds.x) * (float(cache_a[ib_a].dm.x) * float(sum_d) - float(cache_a[ib_a].dm.y) * float(sum_m)));
|
||||
}
|
||||
#endif // MMQ_SHMEM
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q3_K)
|
||||
// 2-byte loads for Q3_K blocks (110 bytes)
|
||||
#ifdef MMQ_SHMEM
|
||||
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
const uint ib_k = ib / 8;
|
||||
const uint hm_idx = iqs * QUANT_R_MMQ;
|
||||
|
|
@ -394,18 +296,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
|
|||
}
|
||||
result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
|
||||
|
||||
return ACC_TYPE(cache_b.ds.x * result);
|
||||
return ACC_TYPE(float(cache_b.ds.x) * result);
|
||||
}
|
||||
#endif // MMQ_SHMEM
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
|
||||
// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
|
||||
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y);
|
||||
}
|
||||
|
||||
#ifdef MMQ_SHMEM
|
||||
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
const uint ib_k = ib / 8;
|
||||
const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
|
||||
|
|
@ -427,7 +323,6 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
|||
(((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4));
|
||||
#endif
|
||||
|
||||
|
||||
if (iqs == 0) {
|
||||
// Scale index
|
||||
const uint is = iqs_k / 8;
|
||||
|
|
@ -464,49 +359,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
|
|||
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
|
||||
}
|
||||
|
||||
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
|
||||
}
|
||||
#endif // MMQ_SHMEM
|
||||
#endif
|
||||
|
||||
#ifdef MMQ_SHMEM
|
||||
void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bool is_in_bounds) {
|
||||
if (is_in_bounds) {
|
||||
const uint ib_outer = ib / 4;
|
||||
const uint ib_inner = ib % 4;
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
|
||||
}
|
||||
|
||||
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
|
||||
buf_b[buf_ib].qs[iqs * 4 ] = values.x;
|
||||
buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
|
||||
buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
|
||||
buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
|
||||
} else {
|
||||
if (iqs == 0) {
|
||||
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(0.0f);
|
||||
}
|
||||
|
||||
buf_b[buf_ib].qs[iqs * 4 ] = 0;
|
||||
buf_b[buf_ib].qs[iqs * 4 + 1] = 0;
|
||||
buf_b[buf_ib].qs[iqs * 4 + 2] = 0;
|
||||
buf_b[buf_ib].qs[iqs * 4 + 3] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
void block_b_to_registers(const uint ib) {
|
||||
cache_b.ds = buf_b[ib].ds;
|
||||
[[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
|
||||
cache_b.qs[iqs] = buf_b[ib].qs[iqs];
|
||||
}
|
||||
return ACC_TYPE(float(cache_b.ds.x) * float(cache_a[ib_a].dm.x) * float(q_sum) - float(cache_a[ib_a].dm.y) * float(cache_b.ds.y));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q6_K)
|
||||
// 2-byte loads for Q6_K blocks (210 bytes)
|
||||
#ifdef MMQ_SHMEM
|
||||
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
const uint ib_k = ib / 8;
|
||||
const uint iqs_k = (ib % 8) * 8 + iqs;
|
||||
|
|
@ -558,32 +416,39 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
|
|||
}
|
||||
result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
|
||||
|
||||
return ACC_TYPE(cache_b.ds.x * result);
|
||||
}
|
||||
#endif // MMQ_SHMEM
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
|
||||
FLOAT_TYPE get_d(uint ib) {
|
||||
return FLOAT_TYPE(data_a[ib].d);
|
||||
return ACC_TYPE(float(cache_b.ds.x) * result);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_MXFP4)
|
||||
FLOAT_TYPE get_d(uint ib) {
|
||||
return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));
|
||||
}
|
||||
#endif
|
||||
void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bool is_in_bounds) {
|
||||
if (is_in_bounds) {
|
||||
const uint ib_outer = ib / 4;
|
||||
const uint ib_inner = ib % 4;
|
||||
|
||||
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
|
||||
FLOAT_TYPE_VEC2 get_dm(uint ib) {
|
||||
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
|
||||
}
|
||||
#endif
|
||||
if (iqs == 0) {
|
||||
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
|
||||
}
|
||||
|
||||
#if defined(DATA_A_Q2_K)
|
||||
FLOAT_TYPE_VEC2 get_dm(uint ib) {
|
||||
const uint ib_k = ib / 8;
|
||||
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
|
||||
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
|
||||
buf_b[buf_ib].qs[iqs * 4 ] = values.x;
|
||||
buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
|
||||
buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
|
||||
buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
|
||||
} else {
|
||||
if (iqs == 0) {
|
||||
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(0.0f);
|
||||
}
|
||||
|
||||
buf_b[buf_ib].qs[iqs * 4 ] = 0;
|
||||
buf_b[buf_ib].qs[iqs * 4 + 1] = 0;
|
||||
buf_b[buf_ib].qs[iqs * 4 + 2] = 0;
|
||||
buf_b[buf_ib].qs[iqs * 4 + 3] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
void block_b_to_registers(const uint ib) {
|
||||
cache_b.ds = buf_b[ib].ds;
|
||||
[[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
|
||||
cache_b.qs[iqs] = buf_b[ib].qs[iqs];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -0,0 +1,43 @@
|
|||
#version 450
|
||||
|
||||
#include "rte.glsl"
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
#define GGML_TRI_TYPE_UPPER_DIAG 0
|
||||
#define GGML_TRI_TYPE_UPPER 1
|
||||
#define GGML_TRI_TYPE_LOWER_DIAG 2
|
||||
#define GGML_TRI_TYPE_LOWER 3
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L);
|
||||
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
|
||||
const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L);
|
||||
const uint i02_offset = i02*p.ne01*p.ne00;
|
||||
const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L);
|
||||
const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
|
||||
|
||||
int param = floatBitsToInt(p.param1);
|
||||
bool pass = false;
|
||||
switch (param) {
|
||||
case GGML_TRI_TYPE_UPPER_DIAG: pass = i00 >= i01; break;
|
||||
case GGML_TRI_TYPE_UPPER: pass = i00 > i01; break;
|
||||
case GGML_TRI_TYPE_LOWER_DIAG: pass = i00 <= i01; break;
|
||||
case GGML_TRI_TYPE_LOWER: pass = i00 < i01; break;
|
||||
}
|
||||
|
||||
if (pass) {
|
||||
const float val = float(data_a[get_aoffset() + src0_idx(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val);
|
||||
} else {
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(0);
|
||||
}
|
||||
}
|
||||
|
|
@ -679,14 +679,20 @@ void process_shaders() {
|
|||
string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
|
||||
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
|
||||
|
||||
// mul mat vec with integer dot product
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (is_legacy_quant(tname)) {
|
||||
if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname)) {
|
||||
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
|
||||
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
@ -846,6 +852,9 @@ void process_shaders() {
|
|||
string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("softplus_f16", "softplus.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("softplus_f32", "softplus.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
|
|
@ -1097,7 +1106,7 @@ void write_output_files() {
|
|||
|
||||
for (const std::string& btype : btypes) {
|
||||
for (const auto& tname : type_names) {
|
||||
if (btype == "q8_1" && !is_legacy_quant(tname)) {
|
||||
if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && !is_k_quant(tname)) {
|
||||
continue;
|
||||
}
|
||||
hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n";
|
||||
|
|
@ -1106,6 +1115,16 @@ void write_output_files() {
|
|||
src << "const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n";
|
||||
src << "const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n";
|
||||
}
|
||||
|
||||
if (btype == "f16") {
|
||||
continue;
|
||||
}
|
||||
hdr << "extern const void * arr_dmmv_id_" << tname << "_" << btype << "_f32_data[3];\n";
|
||||
hdr << "extern const uint64_t arr_dmmv_id_" << tname << "_" << btype << "_f32_len[3];\n";
|
||||
if (basename(input_filepath) == "mul_mat_vec.comp") {
|
||||
src << "const void * arr_dmmv_id_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_id_" << tname << "_" << btype << "_f32_data, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n";
|
||||
src << "const uint64_t arr_dmmv_id_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_id_" << tname << "_" << btype << "_f32_len, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -366,6 +366,7 @@ class MODEL_ARCH(IntEnum):
|
|||
QWEN2VL = auto()
|
||||
QWEN3 = auto()
|
||||
QWEN3MOE = auto()
|
||||
QWEN3NEXT = auto()
|
||||
QWEN3VL = auto()
|
||||
QWEN3VLMOE = auto()
|
||||
PHI2 = auto()
|
||||
|
|
@ -531,6 +532,7 @@ class MODEL_TENSOR(IntEnum):
|
|||
SSM_D = auto()
|
||||
SSM_NORM = auto()
|
||||
SSM_OUT = auto()
|
||||
SSM_BETA_ALPHA = auto() # qwen3next
|
||||
TIME_MIX_W0 = auto()
|
||||
TIME_MIX_W1 = auto()
|
||||
TIME_MIX_W2 = auto()
|
||||
|
|
@ -736,6 +738,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||
MODEL_ARCH.QWEN2VL: "qwen2vl",
|
||||
MODEL_ARCH.QWEN3: "qwen3",
|
||||
MODEL_ARCH.QWEN3MOE: "qwen3moe",
|
||||
MODEL_ARCH.QWEN3NEXT: "qwen3next",
|
||||
MODEL_ARCH.QWEN3VL: "qwen3vl",
|
||||
MODEL_ARCH.QWEN3VLMOE: "qwen3vlmoe",
|
||||
MODEL_ARCH.PHI2: "phi2",
|
||||
|
|
@ -900,6 +903,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
|||
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
|
||||
MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm",
|
||||
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
|
||||
MODEL_TENSOR.SSM_BETA_ALPHA: "blk.{bid}.ssm_ba",
|
||||
MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
|
||||
MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
|
||||
MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2",
|
||||
|
|
@ -1569,6 +1573,35 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
],
|
||||
MODEL_ARCH.QWEN3NEXT: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_POST_NORM,
|
||||
MODEL_TENSOR.ATTN_GATE,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_GATE_INP_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.SSM_A,
|
||||
MODEL_TENSOR.SSM_CONV1D,
|
||||
MODEL_TENSOR.SSM_DT,
|
||||
MODEL_TENSOR.SSM_NORM,
|
||||
MODEL_TENSOR.SSM_IN,
|
||||
MODEL_TENSOR.SSM_BETA_ALPHA,
|
||||
MODEL_TENSOR.SSM_OUT
|
||||
],
|
||||
MODEL_ARCH.QWEN3VL: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
|
|
|||
|
|
@ -371,10 +371,13 @@ class GGUFWriter:
|
|||
|
||||
def add_tensor(
|
||||
self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
|
||||
raw_dtype: GGMLQuantizationType | None = None,
|
||||
raw_dtype: GGMLQuantizationType | None = None, tensor_endianess: GGUFEndian | None = None
|
||||
) -> None:
|
||||
if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \
|
||||
(self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'):
|
||||
# if tensor endianness is not passed, assume it's native to system
|
||||
if tensor_endianess is None:
|
||||
tensor_endianess = GGUFEndian.BIG if sys.byteorder == 'big' else GGUFEndian.LITTLE
|
||||
|
||||
if tensor_endianess != self.endianess:
|
||||
# Don't byteswap inplace since lazy copies cannot handle it
|
||||
tensor = tensor.byteswap(inplace=False)
|
||||
if self.use_temp_file and self.temp_file is None:
|
||||
|
|
@ -397,13 +400,16 @@ class GGUFWriter:
|
|||
if pad != 0:
|
||||
fp.write(bytes([0] * pad))
|
||||
|
||||
def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
|
||||
def write_tensor_data(self, tensor: np.ndarray[Any, Any], tensor_endianess: GGUFEndian | None = None) -> None:
|
||||
if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS:
|
||||
raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}')
|
||||
assert self.fout is not None
|
||||
|
||||
if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \
|
||||
(self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'):
|
||||
# if tensor endianness is not passed, assume it's native to system
|
||||
if tensor_endianess is None:
|
||||
tensor_endianess = GGUFEndian.BIG if sys.byteorder == 'big' else GGUFEndian.LITTLE
|
||||
|
||||
if tensor_endianess != self.endianess:
|
||||
# Don't byteswap inplace since lazy copies cannot handle it
|
||||
tensor = tensor.byteswap(inplace=False)
|
||||
|
||||
|
|
|
|||
|
|
@ -1552,7 +1552,7 @@ class GGUFEditorWindow(QMainWindow):
|
|||
|
||||
# Add tensors (including data)
|
||||
for tensor in self.reader.tensors:
|
||||
writer.add_tensor(tensor.name, tensor.data, raw_shape=tensor.data.shape, raw_dtype=tensor.tensor_type)
|
||||
writer.add_tensor(tensor.name, tensor.data, raw_shape=tensor.data.shape, raw_dtype=tensor.tensor_type, tensor_endianess=self.reader.endianess)
|
||||
|
||||
# Write header and metadata
|
||||
writer.open_output_file(Path(file_path))
|
||||
|
|
|
|||
|
|
@ -94,7 +94,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
|
|||
writer.write_ti_data_to_file()
|
||||
|
||||
for tensor in reader.tensors:
|
||||
writer.write_tensor_data(tensor.data)
|
||||
writer.write_tensor_data(tensor.data, tensor_endianess=reader.endianess)
|
||||
bar.update(tensor.n_bytes)
|
||||
|
||||
writer.close()
|
||||
|
|
|
|||
|
|
@ -672,10 +672,11 @@ class TensorNameMap:
|
|||
),
|
||||
|
||||
MODEL_TENSOR.SSM_IN: (
|
||||
"model.layers.{bid}.in_proj", # mamba-hf
|
||||
"backbone.layers.{bid}.mixer.in_proj", # mamba
|
||||
"model.layers.{bid}.mamba.in_proj", # jamba falcon-h1 granite-hybrid
|
||||
"model.layers.layers.{bid}.mixer.in_proj", # plamo2
|
||||
"model.layers.{bid}.in_proj", # mamba-hf
|
||||
"backbone.layers.{bid}.mixer.in_proj", # mamba
|
||||
"model.layers.{bid}.mamba.in_proj", # jamba falcon-h1 granite-hybrid
|
||||
"model.layers.layers.{bid}.mixer.in_proj", # plamo2
|
||||
"model.layers.{bid}.linear_attn.in_proj_qkvz", # qwen3next
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_CONV1D: (
|
||||
|
|
@ -683,6 +684,7 @@ class TensorNameMap:
|
|||
"backbone.layers.{bid}.mixer.conv1d", # mamba
|
||||
"model.layers.{bid}.mamba.conv1d", # jamba falcon-h1 granite-hybrid
|
||||
"model.layers.layers.{bid}.mixer.conv1d", # plamo2
|
||||
"model.layers.{bid}.linear_attn.conv1d", # qwen3next
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_X: (
|
||||
|
|
@ -697,6 +699,7 @@ class TensorNameMap:
|
|||
"backbone.layers.{bid}.mixer.dt_proj", # mamba
|
||||
"model.layers.{bid}.mamba.dt_proj", # jamba falcon-h1 granite-hybrid
|
||||
"model.layers.layers.{bid}.mixer.dt_proj", # plamo2
|
||||
"model.layers.{bid}.linear_attn.dt_proj", # qwen3next
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_DT_NORM: (
|
||||
|
|
@ -709,6 +712,7 @@ class TensorNameMap:
|
|||
"backbone.layers.{bid}.mixer.A_log", # mamba
|
||||
"model.layers.{bid}.mamba.A_log", # jamba falcon-h1 granite-hybrid
|
||||
"model.layers.layers.{bid}.mixer.A_log", # plamo2
|
||||
"model.layers.{bid}.linear_attn.A_log", # qwen3next
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_B_NORM: (
|
||||
|
|
@ -731,17 +735,23 @@ class TensorNameMap:
|
|||
),
|
||||
|
||||
MODEL_TENSOR.SSM_NORM: (
|
||||
"model.layers.{bid}.mamba.norm", # falcon-h1 granite-hybrid
|
||||
"backbone.layers.{bid}.mixer.norm", # mamba2
|
||||
"model.layers.{bid}.mamba.norm", # falcon-h1 granite-hybrid
|
||||
"model.layers.{bid}.linear_attn.norm", # qwen3next
|
||||
"backbone.layers.{bid}.mixer.norm", # mamba2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_OUT: (
|
||||
"model.layers.{bid}.out_proj", # mamba-hf
|
||||
"backbone.layers.{bid}.mixer.out_proj", # mamba
|
||||
"model.layers.{bid}.mamba.out_proj", # jamba falcon-h1 granite-hybrid
|
||||
"model.layers.{bid}.linear_attn.out_proj", # qwen3next
|
||||
"model.layers.layers.{bid}.mixer.out_proj", # plamo2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_BETA_ALPHA: (
|
||||
"model.layers.{bid}.linear_attn.in_proj_ba", # qwen3next
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_W0: (
|
||||
"model.layers.{bid}.attention.w0", # rwkv7
|
||||
),
|
||||
|
|
|
|||
|
|
@ -114,6 +114,7 @@ add_library(llama
|
|||
models/qwen3vl.cpp
|
||||
models/qwen3vl-moe.cpp
|
||||
models/qwen3moe.cpp
|
||||
models/qwen3next.cpp
|
||||
models/refact.cpp
|
||||
models/rnd1.cpp
|
||||
models/rwkv6-base.cpp
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|||
{ LLM_ARCH_QWEN2VL, "qwen2vl" },
|
||||
{ LLM_ARCH_QWEN3, "qwen3" },
|
||||
{ LLM_ARCH_QWEN3MOE, "qwen3moe" },
|
||||
{ LLM_ARCH_QWEN3NEXT, "qwen3next" },
|
||||
{ LLM_ARCH_QWEN3VL, "qwen3vl" },
|
||||
{ LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" },
|
||||
{ LLM_ARCH_PHI2, "phi2" },
|
||||
|
|
@ -829,6 +830,38 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_QWEN3NEXT,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
|
||||
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
||||
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
|
||||
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
|
||||
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
|
||||
{ LLM_TENSOR_SSM_BETA_ALPHA, "blk.%d.ssm_ba" },
|
||||
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
|
||||
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
|
||||
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_QWEN3VL,
|
||||
{
|
||||
|
|
@ -2556,6 +2589,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
|||
{LLM_TENSOR_SSM_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_SSM_DT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_SSM_BETA_ALPHA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_A1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
|
|
@ -2754,6 +2788,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
|
|||
case LLM_ARCH_LFM2:
|
||||
case LLM_ARCH_LFM2MOE:
|
||||
case LLM_ARCH_NEMOTRON_H:
|
||||
case LLM_ARCH_QWEN3NEXT:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ enum llm_arch {
|
|||
LLM_ARCH_QWEN2VL,
|
||||
LLM_ARCH_QWEN3,
|
||||
LLM_ARCH_QWEN3MOE,
|
||||
LLM_ARCH_QWEN3NEXT,
|
||||
LLM_ARCH_QWEN3VL,
|
||||
LLM_ARCH_QWEN3VLMOE,
|
||||
LLM_ARCH_PHI2,
|
||||
|
|
@ -381,6 +382,7 @@ enum llm_tensor {
|
|||
LLM_TENSOR_SSM_D,
|
||||
LLM_TENSOR_SSM_NORM,
|
||||
LLM_TENSOR_SSM_OUT,
|
||||
LLM_TENSOR_SSM_BETA_ALPHA, // qwen3next
|
||||
LLM_TENSOR_TIME_MIX_W0,
|
||||
LLM_TENSOR_TIME_MIX_W1,
|
||||
LLM_TENSOR_TIME_MIX_W2,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
#include "llama-context.h"
|
||||
|
||||
#include "llama-arch.h"
|
||||
#include "llama-impl.h"
|
||||
#include "llama-batch.h"
|
||||
#include "llama-io.h"
|
||||
|
|
@ -299,7 +300,7 @@ llama_context::llama_context(
|
|||
|
||||
cross.v_embd.clear();
|
||||
|
||||
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
|
||||
const uint32_t n_seqs = cparams.n_seq_max;
|
||||
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
// avoid reserving graphs with zero outputs - assume one output per sequence
|
||||
|
|
@ -542,7 +543,7 @@ bool llama_context::memory_update(bool optimize) {
|
|||
throw std::runtime_error("failed to initialize memory context");
|
||||
}
|
||||
|
||||
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
|
||||
const uint32_t n_seqs = cparams.n_seq_max;
|
||||
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
||||
|
|
@ -1386,6 +1387,9 @@ void llama_context::output_reorder() {
|
|||
//
|
||||
|
||||
uint32_t llama_context::graph_max_nodes() const {
|
||||
if (model.arch == LLM_ARCH_QWEN3NEXT) {
|
||||
return std::max<uint32_t>(8192u, 32u*model.n_tensors());
|
||||
}
|
||||
return std::max<uint32_t>(1024u, 8u*model.n_tensors());
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
// bump if necessary
|
||||
#define LLAMA_MAX_LAYERS 512
|
||||
#define LLAMA_MAX_EXPERTS 384 // Kimi-K2
|
||||
#define LLAMA_MAX_EXPERTS 512 // Qwen3 Next
|
||||
|
||||
enum llama_expert_gating_func_type {
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0,
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
#include "llama-impl.h"
|
||||
#include "llama-mmap.h"
|
||||
#include "llama-batch.h"
|
||||
#include "llama-cparams.h"
|
||||
#include "llama-model-loader.h"
|
||||
|
||||
|
|
@ -2225,6 +2224,29 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_QWEN3NEXT:
|
||||
{
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
|
||||
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
|
||||
// Load linear attention (gated delta net) parameters
|
||||
ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
|
||||
ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
|
||||
ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
|
||||
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
|
||||
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
|
||||
|
||||
// Mark recurrent layers (linear attention layers)
|
||||
for (uint32_t i = 0; i < hparams.n_layer; ++i) {
|
||||
hparams.recurrent_layer_arr[i] = ((i + 1) % 4 != 0); // TODO: extract the magic 4 from "full_attention_interval"
|
||||
}
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 80: type = LLM_TYPE_80B_A3B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
default: throw std::runtime_error("unsupported model architecture");
|
||||
}
|
||||
|
||||
|
|
@ -6415,6 +6437,74 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_QWEN3NEXT:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
|
||||
|
||||
// if output is NULL, init from the input tok embed
|
||||
if (output == NULL) {
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
|
||||
|
||||
// Calculate dimensions from hyperparameters
|
||||
const int64_t head_k_dim = hparams.ssm_d_state;
|
||||
const int64_t head_v_dim = hparams.ssm_d_state;
|
||||
const int64_t n_k_heads = hparams.ssm_n_group;
|
||||
const int64_t n_v_heads = hparams.ssm_dt_rank;
|
||||
const int64_t key_dim = head_k_dim * n_k_heads;
|
||||
const int64_t value_dim = head_v_dim * n_v_heads;
|
||||
const int64_t conv_dim = key_dim * 2 + value_dim;
|
||||
|
||||
// Calculate projection sizes
|
||||
const int64_t qkvz_dim = key_dim * 2 + value_dim * 2;
|
||||
const int64_t ba_dim = n_v_heads * 2;
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
|
||||
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
|
||||
|
||||
if (!hparams.is_recurrent(i)) {
|
||||
// Attention layers
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
|
||||
|
||||
// Q/K normalization for attention layers
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
|
||||
} else {
|
||||
// Linear attention (gated delta net) specific tensors
|
||||
// Create tensors with calculated dimensions
|
||||
layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, qkvz_dim }, 0);
|
||||
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0);
|
||||
layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0);
|
||||
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), { hparams.ssm_dt_rank }, 0);
|
||||
layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_dim }, 0);
|
||||
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0);
|
||||
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0);
|
||||
}
|
||||
|
||||
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0);
|
||||
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
|
||||
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0);
|
||||
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
|
||||
|
||||
// Shared experts
|
||||
layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0);
|
||||
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp }, 0);
|
||||
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp }, 0);
|
||||
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { hparams.n_ff_shexp, n_embd }, 0);
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
throw std::runtime_error("unknown architecture");
|
||||
}
|
||||
|
|
@ -6685,6 +6775,7 @@ void llama_model::print_info() const {
|
|||
arch == LLM_ARCH_FALCON_H1 ||
|
||||
arch == LLM_ARCH_PLAMO2 ||
|
||||
arch == LLM_ARCH_GRANITE_HYBRID ||
|
||||
arch == LLM_ARCH_QWEN3NEXT ||
|
||||
arch == LLM_ARCH_NEMOTRON_H) {
|
||||
LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
|
||||
LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
|
||||
|
|
@ -7426,7 +7517,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
|||
case LLM_ARCH_PANGU_EMBED:
|
||||
{
|
||||
llm = std::make_unique<llm_build_pangu_embedded>(*this, params);
|
||||
}break;
|
||||
} break;
|
||||
case LLM_ARCH_QWEN3NEXT:
|
||||
{
|
||||
llm = std::make_unique<llm_build_qwen3next>(*this, params);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
|
@ -7653,6 +7748,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
|||
case LLM_ARCH_COGVLM:
|
||||
case LLM_ARCH_PANGU_EMBED:
|
||||
case LLM_ARCH_AFMOE:
|
||||
case LLM_ARCH_QWEN3NEXT:
|
||||
return LLAMA_ROPE_TYPE_NEOX;
|
||||
|
||||
case LLM_ARCH_QWEN2VL:
|
||||
|
|
|
|||
|
|
@ -113,6 +113,7 @@ enum llm_type {
|
|||
LLM_TYPE_16B_A1B,
|
||||
LLM_TYPE_21B_A3B, // Ernie MoE small
|
||||
LLM_TYPE_30B_A3B,
|
||||
LLM_TYPE_80B_A3B, // Qwen3 Next
|
||||
LLM_TYPE_100B_A6B,
|
||||
LLM_TYPE_106B_A12B, // GLM-4.5-Air
|
||||
LLM_TYPE_230B_A10B, // Minimax M2
|
||||
|
|
@ -309,6 +310,9 @@ struct llama_layer {
|
|||
struct ggml_tensor * ssm_conv1d_b = nullptr;
|
||||
struct ggml_tensor * ssm_dt_b = nullptr;
|
||||
|
||||
// qwen3next
|
||||
struct ggml_tensor * ssm_beta_alpha = nullptr;
|
||||
|
||||
// rwkv
|
||||
struct ggml_tensor * time_mix_w1 = nullptr;
|
||||
struct ggml_tensor * time_mix_w2 = nullptr;
|
||||
|
|
|
|||
|
|
@ -681,7 +681,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
|||
}
|
||||
LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str());
|
||||
continue;
|
||||
} else if (remapped_name != it.first) {
|
||||
}
|
||||
|
||||
if (remapped_name != it.first) {
|
||||
ggml_set_name(it.second.tensor, remapped_name.c_str());
|
||||
LLAMA_LOG_DEBUG("%s: tensor %s remapped to %s\n", __func__, it.first.c_str(), ggml_get_name(it.second.tensor));
|
||||
}
|
||||
|
|
@ -726,13 +728,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
|||
{
|
||||
const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
|
||||
// attention layers have a non-zero number of kv heads
|
||||
int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
|
||||
int32_t n_layer_attn = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
|
||||
if (llama_model_has_encoder(&model)) {
|
||||
// now n_attn_layer is the number of attention layers in the encoder
|
||||
// now n_layer_attn is the number of attention layers in the encoder
|
||||
// for each decoder block, there are 2 attention layers
|
||||
n_attn_layer += 2 * model.hparams.dec_n_layer;
|
||||
n_layer_attn += 2 * model.hparams.dec_n_layer;
|
||||
}
|
||||
GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected");
|
||||
|
||||
// note: for linear-attention models (such as Qwen3 Next) this is the number of linear layers
|
||||
const int32_t n_layer_recr = std::count(model.hparams.recurrent_layer_arr.begin(), model.hparams.recurrent_layer_arr.end(), true);
|
||||
|
||||
LLAMA_LOG_INFO("%s: n_layer_attn = %d, n_layer_recr = %d, pruned_attention_w = %d\n", __func__, n_layer_attn, n_layer_recr, pruned_attention_w);
|
||||
|
||||
GGML_ASSERT((qs.n_attention_wv == n_layer_attn - pruned_attention_w - n_layer_recr) && "n_attention_wv is unexpected");
|
||||
}
|
||||
|
||||
size_t total_size_org = 0;
|
||||
|
|
|
|||
|
|
@ -2,8 +2,9 @@
|
|||
|
||||
#include "../llama-model.h"
|
||||
#include "../llama-graph.h"
|
||||
#include "../llama-memory-recurrent.h"
|
||||
|
||||
// TODO: remove in follow-up PR - move to .cpp files
|
||||
#include "../llama-memory-recurrent.h"
|
||||
#include <cmath>
|
||||
|
||||
struct llm_graph_context_mamba : public llm_graph_context {
|
||||
|
|
@ -421,7 +422,56 @@ struct llm_build_qwen3vl : public llm_graph_context {
|
|||
struct llm_build_qwen3vlmoe : public llm_graph_context {
|
||||
llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
struct llm_build_qwen3next : public llm_graph_context_mamba {
|
||||
llm_build_qwen3next(const llama_model & model, const llm_graph_params & params);
|
||||
private:
|
||||
ggml_tensor * build_layer_attn(
|
||||
llm_graph_input_attn_kv * inp_attn,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * inp_pos,
|
||||
int il);
|
||||
|
||||
ggml_tensor * build_layer_attn_linear(
|
||||
llm_graph_input_rs * inp,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * causal_mask,
|
||||
ggml_tensor * identity,
|
||||
int il);
|
||||
|
||||
ggml_tensor * build_layer_ffn(
|
||||
ggml_tensor * cur,
|
||||
int il);
|
||||
|
||||
ggml_tensor * build_delta_net_recurrent(
|
||||
ggml_tensor * q,
|
||||
ggml_tensor * k,
|
||||
ggml_tensor * v,
|
||||
ggml_tensor * g,
|
||||
ggml_tensor * beta,
|
||||
ggml_tensor * state,
|
||||
ggml_tensor * causal_mask,
|
||||
ggml_tensor * identity,
|
||||
int il);
|
||||
|
||||
ggml_tensor * build_delta_net_chunking(
|
||||
ggml_tensor * q,
|
||||
ggml_tensor * k,
|
||||
ggml_tensor * v,
|
||||
ggml_tensor * g,
|
||||
ggml_tensor * beta,
|
||||
ggml_tensor * state,
|
||||
ggml_tensor * causal_mask,
|
||||
ggml_tensor * identity,
|
||||
int il);
|
||||
|
||||
ggml_tensor * build_norm_gated(
|
||||
ggml_tensor * input,
|
||||
ggml_tensor * weights,
|
||||
ggml_tensor * gate,
|
||||
int layer);
|
||||
|
||||
const llama_model & model;
|
||||
};
|
||||
|
||||
struct llm_build_qwen : public llm_graph_context {
|
||||
llm_build_qwen(const llama_model & model, const llm_graph_params & params);
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -196,7 +196,7 @@ if (NOT WIN32)
|
|||
llama_build_and_test(test-arg-parser.cpp)
|
||||
endif()
|
||||
|
||||
if (NOT LLAMA_SANITIZE_ADDRESS)
|
||||
if (NOT LLAMA_SANITIZE_ADDRESS AND NOT GGML_SCHED_NO_REALLOC)
|
||||
# TODO: repair known memory leaks
|
||||
llama_build_and_test(test-opt.cpp)
|
||||
endif()
|
||||
|
|
|
|||
|
|
@ -1446,14 +1446,14 @@ struct test_case {
|
|||
const uint64_t target_flops_cpu = 8ULL * GFLOP;
|
||||
const uint64_t target_flops_gpu = 100ULL * GFLOP;
|
||||
uint64_t target_flops = is_cpu ? target_flops_cpu : target_flops_gpu;
|
||||
n_runs = std::min<int>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1;
|
||||
n_runs = (int)std::min<int64_t>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1;
|
||||
} else {
|
||||
// based on memory size
|
||||
const size_t GB = 1ULL << 30;
|
||||
const size_t target_size_cpu = 8 * GB;
|
||||
const size_t target_size_gpu = 32 * GB;
|
||||
size_t target_size = is_cpu ? target_size_cpu : target_size_gpu;
|
||||
n_runs = std::min<int>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1;
|
||||
n_runs = (int)std::min<int64_t>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1;
|
||||
}
|
||||
|
||||
// duplicate the op
|
||||
|
|
@ -8043,7 +8043,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
|||
}
|
||||
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1}));
|
||||
for (auto k : {1, 10, 40}) {
|
||||
|
||||
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2, 1, 1, 1}, 1));
|
||||
for (auto k : {1, 10, 40, 400}) {
|
||||
for (auto nrows : {1, 16}) {
|
||||
for (auto cols : {k, 1000, 65000, 200000}) {
|
||||
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {cols, nrows, 1, 1}, k));
|
||||
|
|
|
|||
|
|
@ -1339,6 +1339,32 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
|||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"literal string with escapes",
|
||||
R"""({
|
||||
"properties": {
|
||||
"code": {
|
||||
"const": " \r \n \" \\ ",
|
||||
"description": "Generated code",
|
||||
"title": "Code",
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"code"
|
||||
],
|
||||
"title": "DecoderResponse",
|
||||
"type": "object"
|
||||
})""",
|
||||
R"""(
|
||||
code ::= "\" \\r \\n \\\" \\\\ \"" space
|
||||
code-kv ::= "\"code\"" space ":" space code
|
||||
root ::= "{" space code-kv "}" space
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
|
|
|||
|
|
@ -21,6 +21,8 @@ set(TARGET_SRCS
|
|||
server-queue.h
|
||||
server-common.cpp
|
||||
server-common.h
|
||||
server-context.cpp
|
||||
server-context.h
|
||||
)
|
||||
set(PUBLIC_ASSETS
|
||||
index.html.gz
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp.
|
|||
**Features:**
|
||||
* LLM inference of F16 and quantized models on GPU and CPU
|
||||
* [OpenAI API](https://github.com/openai/openai-openapi) compatible chat completions and embeddings routes
|
||||
* [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) compatible chat completions
|
||||
* Reranking endpoint (https://github.com/ggml-org/llama.cpp/pull/9510)
|
||||
* Parallel decoding with multi-user support
|
||||
* Continuous batching
|
||||
|
|
@ -1352,6 +1353,77 @@ See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-r
|
|||
}'
|
||||
```
|
||||
|
||||
### POST `/v1/messages`: Anthropic-compatible Messages API
|
||||
|
||||
Given a list of `messages`, returns the assistant's response. Streaming is supported via Server-Sent Events. While no strong claims of compatibility with the Anthropic API spec are made, in our experience it suffices to support many apps.
|
||||
|
||||
*Options:*
|
||||
|
||||
See [Anthropic Messages API documentation](https://docs.anthropic.com/en/api/messages). Tool use requires `--jinja` flag.
|
||||
|
||||
`model`: Model identifier (required)
|
||||
|
||||
`messages`: Array of message objects with `role` and `content` (required)
|
||||
|
||||
`max_tokens`: Maximum tokens to generate (default: 4096)
|
||||
|
||||
`system`: System prompt as string or array of content blocks
|
||||
|
||||
`temperature`: Sampling temperature 0-1 (default: 1.0)
|
||||
|
||||
`top_p`: Nucleus sampling (default: 1.0)
|
||||
|
||||
`top_k`: Top-k sampling
|
||||
|
||||
`stop_sequences`: Array of stop sequences
|
||||
|
||||
`stream`: Enable streaming (default: false)
|
||||
|
||||
`tools`: Array of tool definitions (requires `--jinja`)
|
||||
|
||||
`tool_choice`: Tool selection mode (`{"type": "auto"}`, `{"type": "any"}`, or `{"type": "tool", "name": "..."}`)
|
||||
|
||||
*Examples:*
|
||||
|
||||
```shell
|
||||
curl http://localhost:8080/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "x-api-key: your-api-key" \
|
||||
-d '{
|
||||
"model": "gpt-4",
|
||||
"max_tokens": 1024,
|
||||
"system": "You are a helpful assistant.",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello!"}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
### POST `/v1/messages/count_tokens`: Token Counting
|
||||
|
||||
Counts the number of tokens in a request without generating a response.
|
||||
|
||||
Accepts the same parameters as `/v1/messages`. The `max_tokens` parameter is not required.
|
||||
|
||||
*Example:*
|
||||
|
||||
```shell
|
||||
curl http://localhost:8080/v1/messages/count_tokens \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello!"}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
*Response:*
|
||||
|
||||
```json
|
||||
{"input_tokens": 10}
|
||||
```
|
||||
|
||||
## More examples
|
||||
|
||||
### Interactive mode
|
||||
|
|
|
|||
|
|
@ -257,9 +257,9 @@ const STRING_FORMAT_RULES = {
|
|||
const RESERVED_NAMES = {'root': true, ...PRIMITIVE_RULES, ...STRING_FORMAT_RULES};
|
||||
|
||||
const INVALID_RULE_CHARS_RE = /[^\dA-Za-z-]+/g;
|
||||
const GRAMMAR_LITERAL_ESCAPE_RE = /[\n\r"]/g;
|
||||
const GRAMMAR_LITERAL_ESCAPE_RE = /[\n\r"\\]/g;
|
||||
const GRAMMAR_RANGE_LITERAL_ESCAPE_RE = /[\n\r"\]\-\\]/g;
|
||||
const GRAMMAR_LITERAL_ESCAPES = { '\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]' };
|
||||
const GRAMMAR_LITERAL_ESCAPES = { '\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]', '\\': '\\\\' };
|
||||
|
||||
const NON_LITERAL_SET = new Set('|.()[]{}*+?');
|
||||
const ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = new Set('^$.[]()|{}*+?');
|
||||
|
|
|
|||
|
|
@ -725,7 +725,6 @@ std::vector<server_tokens> tokenize_input_prompts(const llama_vocab * vocab, mtm
|
|||
return result;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// OAI utils
|
||||
//
|
||||
|
|
@ -1048,6 +1047,222 @@ json oaicompat_chat_params_parse(
|
|||
return llama_params;
|
||||
}
|
||||
|
||||
json convert_anthropic_to_oai(const json & body) {
|
||||
json oai_body;
|
||||
|
||||
// Convert system prompt
|
||||
json oai_messages = json::array();
|
||||
auto system_param = json_value(body, "system", json());
|
||||
if (!system_param.is_null()) {
|
||||
std::string system_content;
|
||||
|
||||
if (system_param.is_string()) {
|
||||
system_content = system_param.get<std::string>();
|
||||
} else if (system_param.is_array()) {
|
||||
for (const auto & block : system_param) {
|
||||
if (json_value(block, "type", std::string()) == "text") {
|
||||
system_content += json_value(block, "text", std::string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
oai_messages.push_back({
|
||||
{"role", "system"},
|
||||
{"content", system_content}
|
||||
});
|
||||
}
|
||||
|
||||
// Convert messages
|
||||
if (!body.contains("messages")) {
|
||||
throw std::runtime_error("'messages' is required");
|
||||
}
|
||||
const json & messages = body.at("messages");
|
||||
if (messages.is_array()) {
|
||||
for (const auto & msg : messages) {
|
||||
std::string role = json_value(msg, "role", std::string());
|
||||
|
||||
if (!msg.contains("content")) {
|
||||
if (role == "assistant") {
|
||||
continue;
|
||||
}
|
||||
oai_messages.push_back(msg);
|
||||
continue;
|
||||
}
|
||||
|
||||
const json & content = msg.at("content");
|
||||
|
||||
if (content.is_string()) {
|
||||
oai_messages.push_back(msg);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!content.is_array()) {
|
||||
oai_messages.push_back(msg);
|
||||
continue;
|
||||
}
|
||||
|
||||
json tool_calls = json::array();
|
||||
json converted_content = json::array();
|
||||
json tool_results = json::array();
|
||||
bool has_tool_calls = false;
|
||||
|
||||
for (const auto & block : content) {
|
||||
std::string type = json_value(block, "type", std::string());
|
||||
|
||||
if (type == "text") {
|
||||
converted_content.push_back(block);
|
||||
} else if (type == "image") {
|
||||
json source = json_value(block, "source", json::object());
|
||||
std::string source_type = json_value(source, "type", std::string());
|
||||
|
||||
if (source_type == "base64") {
|
||||
std::string media_type = json_value(source, "media_type", std::string("image/jpeg"));
|
||||
std::string data = json_value(source, "data", std::string());
|
||||
std::ostringstream ss;
|
||||
ss << "data:" << media_type << ";base64," << data;
|
||||
|
||||
converted_content.push_back({
|
||||
{"type", "image_url"},
|
||||
{"image_url", {
|
||||
{"url", ss.str()}
|
||||
}}
|
||||
});
|
||||
} else if (source_type == "url") {
|
||||
std::string url = json_value(source, "url", std::string());
|
||||
converted_content.push_back({
|
||||
{"type", "image_url"},
|
||||
{"image_url", {
|
||||
{"url", url}
|
||||
}}
|
||||
});
|
||||
}
|
||||
} else if (type == "tool_use") {
|
||||
tool_calls.push_back({
|
||||
{"id", json_value(block, "id", std::string())},
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", json_value(block, "name", std::string())},
|
||||
{"arguments", json_value(block, "input", json::object()).dump()}
|
||||
}}
|
||||
});
|
||||
has_tool_calls = true;
|
||||
} else if (type == "tool_result") {
|
||||
std::string tool_use_id = json_value(block, "tool_use_id", std::string());
|
||||
|
||||
auto result_content = json_value(block, "content", json());
|
||||
std::string result_text;
|
||||
if (result_content.is_string()) {
|
||||
result_text = result_content.get<std::string>();
|
||||
} else if (result_content.is_array()) {
|
||||
for (const auto & c : result_content) {
|
||||
if (json_value(c, "type", std::string()) == "text") {
|
||||
result_text += json_value(c, "text", std::string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tool_results.push_back({
|
||||
{"role", "tool"},
|
||||
{"tool_call_id", tool_use_id},
|
||||
{"content", result_text}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (!converted_content.empty() || has_tool_calls) {
|
||||
json new_msg = {{"role", role}};
|
||||
if (!converted_content.empty()) {
|
||||
new_msg["content"] = converted_content;
|
||||
} else if (has_tool_calls) {
|
||||
new_msg["content"] = "";
|
||||
}
|
||||
if (!tool_calls.empty()) {
|
||||
new_msg["tool_calls"] = tool_calls;
|
||||
}
|
||||
oai_messages.push_back(new_msg);
|
||||
}
|
||||
|
||||
for (const auto & tool_msg : tool_results) {
|
||||
oai_messages.push_back(tool_msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
oai_body["messages"] = oai_messages;
|
||||
|
||||
// Convert tools
|
||||
if (body.contains("tools")) {
|
||||
const json & tools = body.at("tools");
|
||||
if (tools.is_array()) {
|
||||
json oai_tools = json::array();
|
||||
for (const auto & tool : tools) {
|
||||
oai_tools.push_back({
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", json_value(tool, "name", std::string())},
|
||||
{"description", json_value(tool, "description", std::string())},
|
||||
{"parameters", tool.contains("input_schema") ? tool.at("input_schema") : json::object()}
|
||||
}}
|
||||
});
|
||||
}
|
||||
oai_body["tools"] = oai_tools;
|
||||
}
|
||||
}
|
||||
|
||||
// Convert tool_choice
|
||||
if (body.contains("tool_choice")) {
|
||||
const json & tc = body.at("tool_choice");
|
||||
if (tc.is_object()) {
|
||||
std::string type = json_value(tc, "type", std::string());
|
||||
if (type == "auto") {
|
||||
oai_body["tool_choice"] = "auto";
|
||||
} else if (type == "any" || type == "tool") {
|
||||
oai_body["tool_choice"] = "required";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert stop_sequences to stop
|
||||
if (body.contains("stop_sequences")) {
|
||||
oai_body["stop"] = body.at("stop_sequences");
|
||||
}
|
||||
|
||||
// Handle max_tokens (required in Anthropic, but we're permissive)
|
||||
if (body.contains("max_tokens")) {
|
||||
oai_body["max_tokens"] = body.at("max_tokens");
|
||||
} else {
|
||||
oai_body["max_tokens"] = 4096;
|
||||
}
|
||||
|
||||
// Pass through common params
|
||||
for (const auto & key : {"temperature", "top_p", "top_k", "stream"}) {
|
||||
if (body.contains(key)) {
|
||||
oai_body[key] = body.at(key);
|
||||
}
|
||||
}
|
||||
|
||||
// Handle Anthropic-specific thinking param
|
||||
if (body.contains("thinking")) {
|
||||
json thinking = json_value(body, "thinking", json::object());
|
||||
std::string thinking_type = json_value(thinking, "type", std::string());
|
||||
if (thinking_type == "enabled") {
|
||||
int budget_tokens = json_value(thinking, "budget_tokens", 10000);
|
||||
oai_body["thinking_budget_tokens"] = budget_tokens;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle Anthropic-specific metadata param
|
||||
if (body.contains("metadata")) {
|
||||
json metadata = json_value(body, "metadata", json::object());
|
||||
std::string user_id = json_value(metadata, "user_id", std::string());
|
||||
if (!user_id.empty()) {
|
||||
oai_body["__metadata_user_id"] = user_id;
|
||||
}
|
||||
}
|
||||
|
||||
return oai_body;
|
||||
}
|
||||
|
||||
json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64) {
|
||||
json data = json::array();
|
||||
int32_t n_tokens = 0;
|
||||
|
|
@ -1211,7 +1426,7 @@ std::string tokens_to_output_formatted_string(const llama_context * ctx, const l
|
|||
|
||||
// format server-sent event (SSE), return the formatted string to send
|
||||
// note: if data is a json array, it will be sent as multiple events, one per item
|
||||
std::string format_sse(const json & data) {
|
||||
std::string format_oai_sse(const json & data) {
|
||||
std::ostringstream ss;
|
||||
auto send_single = [&ss](const json & data) {
|
||||
ss << "data: " <<
|
||||
|
|
@ -1230,6 +1445,29 @@ std::string format_sse(const json & data) {
|
|||
return ss.str();
|
||||
}
|
||||
|
||||
std::string format_anthropic_sse(const json & data) {
|
||||
std::ostringstream ss;
|
||||
|
||||
auto send_event = [&ss](const json & event_obj) {
|
||||
if (event_obj.contains("event") && event_obj.contains("data")) {
|
||||
ss << "event: " << event_obj.at("event").get<std::string>() << "\n";
|
||||
ss << "data: " << safe_json_to_str(event_obj.at("data")) << "\n\n";
|
||||
} else {
|
||||
ss << "data: " << safe_json_to_str(event_obj) << "\n\n";
|
||||
}
|
||||
};
|
||||
|
||||
if (data.is_array()) {
|
||||
for (const auto & event : data) {
|
||||
send_event(event);
|
||||
}
|
||||
} else {
|
||||
send_event(data);
|
||||
}
|
||||
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
bool is_valid_utf8(const std::string & str) {
|
||||
const unsigned char* bytes = reinterpret_cast<const unsigned char*>(str.data());
|
||||
const unsigned char* end = bytes + str.length();
|
||||
|
|
|
|||
|
|
@ -294,6 +294,9 @@ json oaicompat_chat_params_parse(
|
|||
const oaicompat_parser_options & opt,
|
||||
std::vector<raw_buffer> & out_files);
|
||||
|
||||
// convert Anthropic Messages API format to OpenAI Chat Completions API format
|
||||
json convert_anthropic_to_oai(const json & body);
|
||||
|
||||
// TODO: move it to server-task.cpp
|
||||
json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false);
|
||||
|
||||
|
|
@ -320,7 +323,10 @@ std::string tokens_to_output_formatted_string(const llama_context * ctx, const l
|
|||
|
||||
// format server-sent event (SSE), return the formatted string to send
|
||||
// note: if data is a json array, it will be sent as multiple events, one per item
|
||||
std::string format_sse(const json & data);
|
||||
std::string format_oai_sse(const json & data);
|
||||
|
||||
// format Anthropic-style SSE with event types
|
||||
std::string format_anthropic_sse(const json & data);
|
||||
|
||||
bool is_valid_utf8(const std::string & str);
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,83 @@
|
|||
#include "server-http.h"
|
||||
#include "server-task.h"
|
||||
#include "server-queue.h"
|
||||
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
#include <cstddef>
|
||||
#include <memory>
|
||||
|
||||
struct server_context_impl; // private implementation
|
||||
|
||||
struct server_context {
|
||||
std::unique_ptr<server_context_impl> impl;
|
||||
|
||||
server_context();
|
||||
~server_context();
|
||||
|
||||
// initialize slots and server-related data
|
||||
void init();
|
||||
|
||||
// load the model and initialize llama_context
|
||||
// returns true on success
|
||||
bool load_model(const common_params & params);
|
||||
|
||||
// this function will block main thread until termination
|
||||
void start_loop();
|
||||
|
||||
// terminate main loop (will unblock start_loop)
|
||||
void terminate();
|
||||
|
||||
// get the underlaying llama_context
|
||||
llama_context * get_llama_context() const;
|
||||
|
||||
// get the underlaying queue_tasks and queue_results
|
||||
// used by CLI application
|
||||
std::pair<server_queue &, server_response &> get_queues();
|
||||
};
|
||||
|
||||
|
||||
// forward declarations
|
||||
struct server_res_generator;
|
||||
|
||||
struct server_routes {
|
||||
server_routes(const common_params & params, server_context & ctx_server, std::function<bool()> is_ready = []() { return true; })
|
||||
: params(params), ctx_server(*ctx_server.impl), is_ready(is_ready) {
|
||||
init_routes();
|
||||
}
|
||||
|
||||
void init_routes();
|
||||
// handlers using lambda function, so that they can capture `this` without `std::bind`
|
||||
server_http_context::handler_t get_health;
|
||||
server_http_context::handler_t get_metrics;
|
||||
server_http_context::handler_t get_slots;
|
||||
server_http_context::handler_t post_slots;
|
||||
server_http_context::handler_t get_props;
|
||||
server_http_context::handler_t post_props;
|
||||
server_http_context::handler_t get_api_show;
|
||||
server_http_context::handler_t post_infill;
|
||||
server_http_context::handler_t post_completions;
|
||||
server_http_context::handler_t post_completions_oai;
|
||||
server_http_context::handler_t post_chat_completions;
|
||||
server_http_context::handler_t post_anthropic_messages;
|
||||
server_http_context::handler_t post_anthropic_count_tokens;
|
||||
server_http_context::handler_t post_apply_template;
|
||||
server_http_context::handler_t get_models;
|
||||
server_http_context::handler_t post_tokenize;
|
||||
server_http_context::handler_t post_detokenize;
|
||||
server_http_context::handler_t post_embeddings;
|
||||
server_http_context::handler_t post_embeddings_oai;
|
||||
server_http_context::handler_t post_rerank;
|
||||
server_http_context::handler_t get_lora_adapters;
|
||||
server_http_context::handler_t post_lora_adapters;
|
||||
private:
|
||||
// TODO: move these outside of server_routes?
|
||||
std::unique_ptr<server_res_generator> handle_slots_save(const server_http_req & req, int id_slot);
|
||||
std::unique_ptr<server_res_generator> handle_slots_restore(const server_http_req & req, int id_slot);
|
||||
std::unique_ptr<server_res_generator> handle_slots_erase(const server_http_req &, int id_slot);
|
||||
std::unique_ptr<server_res_generator> handle_embeddings_impl(const server_http_req & req, task_response_type res_type);
|
||||
|
||||
const common_params & params;
|
||||
server_context_impl & ctx_server;
|
||||
std::function<bool()> is_ready;
|
||||
};
|
||||
|
|
@ -136,15 +136,22 @@ bool server_http_context::init(const common_params & params) {
|
|||
return true;
|
||||
}
|
||||
|
||||
// Check for API key in the header
|
||||
auto auth_header = req.get_header_value("Authorization");
|
||||
// Check for API key in the Authorization header
|
||||
std::string req_api_key = req.get_header_value("Authorization");
|
||||
if (req_api_key.empty()) {
|
||||
// retry with anthropic header
|
||||
req_api_key = req.get_header_value("X-Api-Key");
|
||||
}
|
||||
|
||||
// remove the "Bearer " prefix if needed
|
||||
std::string prefix = "Bearer ";
|
||||
if (auth_header.substr(0, prefix.size()) == prefix) {
|
||||
std::string received_api_key = auth_header.substr(prefix.size());
|
||||
if (std::find(api_keys.begin(), api_keys.end(), received_api_key) != api_keys.end()) {
|
||||
return true; // API key is valid
|
||||
}
|
||||
if (req_api_key.substr(0, prefix.size()) == prefix) {
|
||||
req_api_key = req_api_key.substr(prefix.size());
|
||||
}
|
||||
|
||||
// validate the API key
|
||||
if (std::find(api_keys.begin(), api_keys.end(), req_api_key) != api_keys.end()) {
|
||||
return true; // API key is valid
|
||||
}
|
||||
|
||||
// API key is invalid or not provided
|
||||
|
|
|
|||
|
|
@ -199,7 +199,7 @@ server_task_result_ptr server_response::recv(const std::unordered_set<int> & id_
|
|||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
condition_results.wait(lock, [&]{
|
||||
if (!running) {
|
||||
RES_DBG("%s : queue result stop\n", __func__);
|
||||
RES_DBG("%s : queue result stop\n", "recv");
|
||||
std::terminate(); // we cannot return here since the caller is HTTP code
|
||||
}
|
||||
return !queue_results.empty();
|
||||
|
|
@ -266,3 +266,86 @@ void server_response::terminate() {
|
|||
running = false;
|
||||
condition_results.notify_all();
|
||||
}
|
||||
|
||||
//
|
||||
// server_response_reader
|
||||
//
|
||||
|
||||
void server_response_reader::post_tasks(std::vector<server_task> && tasks) {
|
||||
id_tasks = server_task::get_list_id(tasks);
|
||||
queue_results.add_waiting_tasks(tasks);
|
||||
queue_tasks.post(std::move(tasks));
|
||||
}
|
||||
|
||||
bool server_response_reader::has_next() const {
|
||||
return !cancelled && received_count < id_tasks.size();
|
||||
}
|
||||
|
||||
// return nullptr if should_stop() is true before receiving a result
|
||||
// note: if one error is received, it will stop further processing and return error result
|
||||
server_task_result_ptr server_response_reader::next(const std::function<bool()> & should_stop) {
|
||||
while (true) {
|
||||
server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, polling_interval_seconds);
|
||||
if (result == nullptr) {
|
||||
// timeout, check stop condition
|
||||
if (should_stop()) {
|
||||
SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n");
|
||||
return nullptr;
|
||||
}
|
||||
} else {
|
||||
if (result->is_error()) {
|
||||
stop(); // cancel remaining tasks
|
||||
SRV_DBG("%s", "received error result, stopping further processing\n");
|
||||
return result;
|
||||
}
|
||||
if (result->is_stop()) {
|
||||
received_count++;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// should not reach here
|
||||
}
|
||||
|
||||
server_response_reader::batch_response server_response_reader::wait_for_all(const std::function<bool()> & should_stop) {
|
||||
batch_response batch_res;
|
||||
batch_res.results.resize(id_tasks.size());
|
||||
while (has_next()) {
|
||||
auto res = next(should_stop);
|
||||
if (res == nullptr) {
|
||||
batch_res.is_terminated = true;
|
||||
return batch_res;
|
||||
}
|
||||
if (res->is_error()) {
|
||||
batch_res.error = std::move(res);
|
||||
return batch_res;
|
||||
}
|
||||
const size_t idx = res->get_index();
|
||||
GGML_ASSERT(idx < batch_res.results.size() && "index out of range");
|
||||
GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received");
|
||||
batch_res.results[idx] = std::move(res);
|
||||
}
|
||||
return batch_res;
|
||||
}
|
||||
|
||||
void server_response_reader::stop() {
|
||||
queue_results.remove_waiting_task_ids(id_tasks);
|
||||
if (has_next() && !cancelled) {
|
||||
// if tasks is not finished yet, cancel them
|
||||
cancelled = true;
|
||||
std::vector<server_task> cancel_tasks;
|
||||
cancel_tasks.reserve(id_tasks.size());
|
||||
for (const auto & id_task : id_tasks) {
|
||||
SRV_WRN("cancel task, id_task = %d\n", id_task);
|
||||
server_task task(SERVER_TASK_TYPE_CANCEL);
|
||||
task.id_target = id_task;
|
||||
queue_results.remove_waiting_task_id(id_task);
|
||||
cancel_tasks.push_back(std::move(task));
|
||||
}
|
||||
// push to beginning of the queue, so it has highest priority
|
||||
queue_tasks.post(std::move(cancel_tasks), true);
|
||||
} else {
|
||||
SRV_DBG("%s", "all tasks already finished, no need to cancel\n");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -108,3 +108,39 @@ public:
|
|||
// terminate the waiting loop
|
||||
void terminate();
|
||||
};
|
||||
|
||||
// utility class to make working with server_queue and server_response easier
|
||||
// it provides a generator-like API for server responses
|
||||
// support pooling connection state and aggregating multiple results
|
||||
struct server_response_reader {
|
||||
std::unordered_set<int> id_tasks;
|
||||
server_queue & queue_tasks;
|
||||
server_response & queue_results;
|
||||
size_t received_count = 0;
|
||||
bool cancelled = false;
|
||||
int polling_interval_seconds;
|
||||
|
||||
// should_stop function will be called each polling_interval_seconds
|
||||
server_response_reader(std::pair<server_queue &, server_response &> server_queues, int polling_interval_seconds)
|
||||
: queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {}
|
||||
~server_response_reader() {
|
||||
stop();
|
||||
}
|
||||
|
||||
void post_tasks(std::vector<server_task> && tasks);
|
||||
bool has_next() const;
|
||||
|
||||
// return nullptr if should_stop() is true before receiving a result
|
||||
// note: if one error is received, it will stop further processing and return error result
|
||||
server_task_result_ptr next(const std::function<bool()> & should_stop);
|
||||
|
||||
struct batch_response {
|
||||
bool is_terminated = false; // if true, indicates that processing was stopped before all results were received
|
||||
std::vector<server_task_result_ptr> results;
|
||||
server_task_result_ptr error; // nullptr if no error
|
||||
};
|
||||
// aggregate multiple results
|
||||
batch_response wait_for_all(const std::function<bool()> & should_stop);
|
||||
|
||||
void stop();
|
||||
};
|
||||
|
|
|
|||
|
|
@ -565,15 +565,17 @@ std::vector<unsigned char> completion_token_output::str_to_bytes(const std::stri
|
|||
// server_task_result_cmpl_final
|
||||
//
|
||||
json server_task_result_cmpl_final::to_json() {
|
||||
switch (oaicompat) {
|
||||
case OAICOMPAT_TYPE_NONE:
|
||||
switch (res_type) {
|
||||
case TASK_RESPONSE_TYPE_NONE:
|
||||
return to_json_non_oaicompat();
|
||||
case OAICOMPAT_TYPE_COMPLETION:
|
||||
case TASK_RESPONSE_TYPE_OAI_CMPL:
|
||||
return to_json_oaicompat();
|
||||
case OAICOMPAT_TYPE_CHAT:
|
||||
case TASK_RESPONSE_TYPE_OAI_CHAT:
|
||||
return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat();
|
||||
case TASK_RESPONSE_TYPE_ANTHROPIC:
|
||||
return stream ? to_json_anthropic_stream() : to_json_anthropic();
|
||||
default:
|
||||
GGML_ASSERT(false && "Invalid oaicompat_type");
|
||||
GGML_ASSERT(false && "Invalid task_response_type");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -768,19 +770,203 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() {
|
|||
return deltas;
|
||||
}
|
||||
|
||||
json server_task_result_cmpl_final::to_json_anthropic() {
|
||||
std::string stop_reason = "max_tokens";
|
||||
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||
stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use";
|
||||
}
|
||||
|
||||
json content_blocks = json::array();
|
||||
|
||||
common_chat_msg msg;
|
||||
if (!oaicompat_msg.empty()) {
|
||||
msg = oaicompat_msg;
|
||||
} else {
|
||||
msg.role = "assistant";
|
||||
msg.content = content;
|
||||
}
|
||||
|
||||
if (!msg.content.empty()) {
|
||||
content_blocks.push_back({
|
||||
{"type", "text"},
|
||||
{"text", msg.content}
|
||||
});
|
||||
}
|
||||
|
||||
for (const auto & tool_call : msg.tool_calls) {
|
||||
json tool_use_block = {
|
||||
{"type", "tool_use"},
|
||||
{"id", tool_call.id},
|
||||
{"name", tool_call.name}
|
||||
};
|
||||
|
||||
try {
|
||||
tool_use_block["input"] = json::parse(tool_call.arguments);
|
||||
} catch (const std::exception &) {
|
||||
tool_use_block["input"] = json::object();
|
||||
}
|
||||
|
||||
content_blocks.push_back(tool_use_block);
|
||||
}
|
||||
|
||||
json res = {
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"type", "message"},
|
||||
{"role", "assistant"},
|
||||
{"content", content_blocks},
|
||||
{"model", oaicompat_model},
|
||||
{"stop_reason", stop_reason},
|
||||
{"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)},
|
||||
{"usage", {
|
||||
{"input_tokens", n_prompt_tokens},
|
||||
{"output_tokens", n_decoded}
|
||||
}}
|
||||
};
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
json server_task_result_cmpl_final::to_json_anthropic_stream() {
|
||||
json events = json::array();
|
||||
|
||||
std::string stop_reason = "max_tokens";
|
||||
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||
stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use";
|
||||
}
|
||||
|
||||
bool has_text = !oaicompat_msg.content.empty();
|
||||
size_t num_tool_calls = oaicompat_msg.tool_calls.size();
|
||||
|
||||
bool text_block_started = false;
|
||||
std::unordered_set<size_t> tool_calls_started;
|
||||
|
||||
for (const auto & diff : oaicompat_msg_diffs) {
|
||||
if (!diff.content_delta.empty()) {
|
||||
if (!text_block_started) {
|
||||
events.push_back({
|
||||
{"event", "content_block_start"},
|
||||
{"data", {
|
||||
{"type", "content_block_start"},
|
||||
{"index", 0},
|
||||
{"content_block", {
|
||||
{"type", "text"},
|
||||
{"text", ""}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
text_block_started = true;
|
||||
}
|
||||
|
||||
events.push_back({
|
||||
{"event", "content_block_delta"},
|
||||
{"data", {
|
||||
{"type", "content_block_delta"},
|
||||
{"index", 0},
|
||||
{"delta", {
|
||||
{"type", "text_delta"},
|
||||
{"text", diff.content_delta}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
}
|
||||
|
||||
if (diff.tool_call_index != std::string::npos) {
|
||||
size_t content_block_index = (has_text ? 1 : 0) + diff.tool_call_index;
|
||||
|
||||
if (tool_calls_started.find(diff.tool_call_index) == tool_calls_started.end()) {
|
||||
const auto & full_tool_call = oaicompat_msg.tool_calls[diff.tool_call_index];
|
||||
|
||||
events.push_back({
|
||||
{"event", "content_block_start"},
|
||||
{"data", {
|
||||
{"type", "content_block_start"},
|
||||
{"index", content_block_index},
|
||||
{"content_block", {
|
||||
{"type", "tool_use"},
|
||||
{"id", full_tool_call.id},
|
||||
{"name", full_tool_call.name}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
tool_calls_started.insert(diff.tool_call_index);
|
||||
}
|
||||
|
||||
if (!diff.tool_call_delta.arguments.empty()) {
|
||||
events.push_back({
|
||||
{"event", "content_block_delta"},
|
||||
{"data", {
|
||||
{"type", "content_block_delta"},
|
||||
{"index", content_block_index},
|
||||
{"delta", {
|
||||
{"type", "input_json_delta"},
|
||||
{"partial_json", diff.tool_call_delta.arguments}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (has_text) {
|
||||
events.push_back({
|
||||
{"event", "content_block_stop"},
|
||||
{"data", {
|
||||
{"type", "content_block_stop"},
|
||||
{"index", 0}
|
||||
}}
|
||||
});
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < num_tool_calls; i++) {
|
||||
size_t content_block_index = (has_text ? 1 : 0) + i;
|
||||
events.push_back({
|
||||
{"event", "content_block_stop"},
|
||||
{"data", {
|
||||
{"type", "content_block_stop"},
|
||||
{"index", content_block_index}
|
||||
}}
|
||||
});
|
||||
}
|
||||
|
||||
events.push_back({
|
||||
{"event", "message_delta"},
|
||||
{"data", {
|
||||
{"type", "message_delta"},
|
||||
{"delta", {
|
||||
{"stop_reason", stop_reason},
|
||||
{"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)}
|
||||
}},
|
||||
{"usage", {
|
||||
{"output_tokens", n_decoded}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
|
||||
events.push_back({
|
||||
{"event", "message_stop"},
|
||||
{"data", {
|
||||
{"type", "message_stop"}
|
||||
}}
|
||||
});
|
||||
|
||||
return events;
|
||||
}
|
||||
|
||||
//
|
||||
// server_task_result_cmpl_partial
|
||||
//
|
||||
json server_task_result_cmpl_partial::to_json() {
|
||||
switch (oaicompat) {
|
||||
case OAICOMPAT_TYPE_NONE:
|
||||
switch (res_type) {
|
||||
case TASK_RESPONSE_TYPE_NONE:
|
||||
return to_json_non_oaicompat();
|
||||
case OAICOMPAT_TYPE_COMPLETION:
|
||||
case TASK_RESPONSE_TYPE_OAI_CMPL:
|
||||
return to_json_oaicompat();
|
||||
case OAICOMPAT_TYPE_CHAT:
|
||||
case TASK_RESPONSE_TYPE_OAI_CHAT:
|
||||
return to_json_oaicompat_chat();
|
||||
case TASK_RESPONSE_TYPE_ANTHROPIC:
|
||||
return to_json_anthropic();
|
||||
default:
|
||||
GGML_ASSERT(false && "Invalid oaicompat_type");
|
||||
GGML_ASSERT(false && "Invalid task_response_type");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -905,7 +1091,7 @@ json server_task_result_cmpl_partial::to_json_oaicompat_chat() {
|
|||
// server_task_result_embd
|
||||
//
|
||||
json server_task_result_embd::to_json() {
|
||||
return oaicompat == OAICOMPAT_TYPE_EMBEDDING
|
||||
return res_type == TASK_RESPONSE_TYPE_OAI_EMBD
|
||||
? to_json_oaicompat()
|
||||
: to_json_non_oaicompat();
|
||||
}
|
||||
|
|
@ -936,6 +1122,102 @@ json server_task_result_rerank::to_json() {
|
|||
};
|
||||
}
|
||||
|
||||
json server_task_result_cmpl_partial::to_json_anthropic() {
|
||||
json events = json::array();
|
||||
bool first = (n_decoded == 1);
|
||||
static bool text_block_started = false;
|
||||
|
||||
if (first) {
|
||||
text_block_started = false;
|
||||
|
||||
events.push_back({
|
||||
{"event", "message_start"},
|
||||
{"data", {
|
||||
{"type", "message_start"},
|
||||
{"message", {
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"type", "message"},
|
||||
{"role", "assistant"},
|
||||
{"content", json::array()},
|
||||
{"model", oaicompat_model},
|
||||
{"stop_reason", nullptr},
|
||||
{"stop_sequence", nullptr},
|
||||
{"usage", {
|
||||
{"input_tokens", n_prompt_tokens},
|
||||
{"output_tokens", 0}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
}
|
||||
|
||||
for (const auto & diff : oaicompat_msg_diffs) {
|
||||
if (!diff.content_delta.empty()) {
|
||||
if (!text_block_started) {
|
||||
events.push_back({
|
||||
{"event", "content_block_start"},
|
||||
{"data", {
|
||||
{"type", "content_block_start"},
|
||||
{"index", 0},
|
||||
{"content_block", {
|
||||
{"type", "text"},
|
||||
{"text", ""}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
text_block_started = true;
|
||||
}
|
||||
|
||||
events.push_back({
|
||||
{"event", "content_block_delta"},
|
||||
{"data", {
|
||||
{"type", "content_block_delta"},
|
||||
{"index", 0},
|
||||
{"delta", {
|
||||
{"type", "text_delta"},
|
||||
{"text", diff.content_delta}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
}
|
||||
|
||||
if (diff.tool_call_index != std::string::npos) {
|
||||
size_t content_block_index = (text_block_started ? 1 : 0) + diff.tool_call_index;
|
||||
|
||||
if (!diff.tool_call_delta.name.empty()) {
|
||||
events.push_back({
|
||||
{"event", "content_block_start"},
|
||||
{"data", {
|
||||
{"type", "content_block_start"},
|
||||
{"index", content_block_index},
|
||||
{"content_block", {
|
||||
{"type", "tool_use"},
|
||||
{"id", diff.tool_call_delta.id},
|
||||
{"name", diff.tool_call_delta.name}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
}
|
||||
|
||||
if (!diff.tool_call_delta.arguments.empty()) {
|
||||
events.push_back({
|
||||
{"event", "content_block_delta"},
|
||||
{"data", {
|
||||
{"type", "content_block_delta"},
|
||||
{"index", content_block_index},
|
||||
{"delta", {
|
||||
{"type", "input_json_delta"},
|
||||
{"partial_json", diff.tool_call_delta.arguments}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return events;
|
||||
}
|
||||
|
||||
//
|
||||
// server_task_result_error
|
||||
//
|
||||
|
|
|
|||
|
|
@ -27,11 +27,12 @@ enum server_task_type {
|
|||
};
|
||||
|
||||
// TODO: change this to more generic "response_format" to replace the "format_response_*" in server-common
|
||||
enum oaicompat_type {
|
||||
OAICOMPAT_TYPE_NONE,
|
||||
OAICOMPAT_TYPE_CHAT,
|
||||
OAICOMPAT_TYPE_COMPLETION,
|
||||
OAICOMPAT_TYPE_EMBEDDING,
|
||||
enum task_response_type {
|
||||
TASK_RESPONSE_TYPE_NONE, // llama.cpp native format
|
||||
TASK_RESPONSE_TYPE_OAI_CHAT,
|
||||
TASK_RESPONSE_TYPE_OAI_CMPL,
|
||||
TASK_RESPONSE_TYPE_OAI_EMBD,
|
||||
TASK_RESPONSE_TYPE_ANTHROPIC,
|
||||
};
|
||||
|
||||
enum stop_type {
|
||||
|
|
@ -66,9 +67,9 @@ struct task_params {
|
|||
struct common_params_sampling sampling;
|
||||
struct common_params_speculative speculative;
|
||||
|
||||
// OAI-compat fields
|
||||
// response formatting
|
||||
bool verbose = false;
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_syntax oaicompat_chat_syntax;
|
||||
|
|
@ -227,12 +228,12 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||
|
||||
task_params generation_params;
|
||||
|
||||
// OAI-compat fields
|
||||
bool verbose = false;
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_msg oaicompat_msg;
|
||||
// response formatting
|
||||
bool verbose = false;
|
||||
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_msg oaicompat_msg;
|
||||
|
||||
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
|
||||
|
||||
|
|
@ -253,6 +254,10 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||
json to_json_oaicompat_chat();
|
||||
|
||||
json to_json_oaicompat_chat_stream();
|
||||
|
||||
json to_json_anthropic();
|
||||
|
||||
json to_json_anthropic_stream();
|
||||
};
|
||||
|
||||
struct server_task_result_cmpl_partial : server_task_result {
|
||||
|
|
@ -270,11 +275,11 @@ struct server_task_result_cmpl_partial : server_task_result {
|
|||
result_timings timings;
|
||||
result_prompt_progress progress;
|
||||
|
||||
// OAI-compat fields
|
||||
bool verbose = false;
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
// response formatting
|
||||
bool verbose = false;
|
||||
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
|
||||
|
||||
virtual int get_index() override {
|
||||
|
|
@ -292,6 +297,8 @@ struct server_task_result_cmpl_partial : server_task_result {
|
|||
json to_json_oaicompat();
|
||||
|
||||
json to_json_oaicompat_chat();
|
||||
|
||||
json to_json_anthropic();
|
||||
};
|
||||
|
||||
struct server_task_result_embd : server_task_result {
|
||||
|
|
@ -300,8 +307,8 @@ struct server_task_result_embd : server_task_result {
|
|||
|
||||
int32_t n_tokens;
|
||||
|
||||
// OAI-compat fields
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
// response formatting
|
||||
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
|
||||
virtual int get_index() override {
|
||||
return index;
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -13,3 +13,9 @@ def stop_server_after_each_test():
|
|||
) # copy the set to prevent 'Set changed size during iteration'
|
||||
for server in instances:
|
||||
server.stop()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def do_something():
|
||||
# this will be run once per test session, before any tests
|
||||
ServerPreset.load_all()
|
||||
|
|
|
|||
|
|
@ -5,12 +5,6 @@ from utils import *
|
|||
server = ServerPreset.tinyllama2()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def do_something():
|
||||
# this will be run once per test session, before any tests
|
||||
ServerPreset.load_all()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
|
|
|
|||
|
|
@ -0,0 +1,807 @@
|
|||
#!/usr/bin/env python3
|
||||
import pytest
|
||||
import base64
|
||||
import requests
|
||||
|
||||
from utils import *
|
||||
|
||||
server: ServerProcess
|
||||
|
||||
|
||||
def get_test_image_base64() -> str:
|
||||
"""Get a test image in base64 format"""
|
||||
# Use the same test image as test_vision_api.py
|
||||
IMG_URL = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
|
||||
response = requests.get(IMG_URL)
|
||||
response.raise_for_status()
|
||||
return base64.b64encode(response.content).decode("utf-8")
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama2()
|
||||
server.model_alias = "tinyllama-2-anthropic"
|
||||
server.server_port = 8082
|
||||
server.n_slots = 1
|
||||
server.n_ctx = 8192
|
||||
server.n_batch = 2048
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vision_server():
|
||||
"""Separate fixture for vision tests that require multimodal support"""
|
||||
global server
|
||||
server = ServerPreset.tinygemma3()
|
||||
server.offline = False # Allow downloading the model
|
||||
server.model_alias = "tinygemma3-anthropic"
|
||||
server.server_port = 8083 # Different port to avoid conflicts
|
||||
server.n_slots = 1
|
||||
return server
|
||||
|
||||
|
||||
# Basic message tests
|
||||
|
||||
def test_anthropic_messages_basic():
|
||||
"""Test basic Anthropic messages endpoint"""
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 50,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Say hello"}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200, f"Expected 200, got {res.status_code}"
|
||||
assert res.body["type"] == "message", f"Expected type 'message', got {res.body.get('type')}"
|
||||
assert res.body["role"] == "assistant", f"Expected role 'assistant', got {res.body.get('role')}"
|
||||
assert "content" in res.body, "Missing 'content' field"
|
||||
assert isinstance(res.body["content"], list), "Content should be an array"
|
||||
assert len(res.body["content"]) > 0, "Content array should not be empty"
|
||||
assert res.body["content"][0]["type"] == "text", "First content block should be text"
|
||||
assert "text" in res.body["content"][0], "Text content block missing 'text' field"
|
||||
assert res.body["stop_reason"] in ["end_turn", "max_tokens"], f"Invalid stop_reason: {res.body.get('stop_reason')}"
|
||||
assert "usage" in res.body, "Missing 'usage' field"
|
||||
assert "input_tokens" in res.body["usage"], "Missing usage.input_tokens"
|
||||
assert "output_tokens" in res.body["usage"], "Missing usage.output_tokens"
|
||||
assert isinstance(res.body["usage"]["input_tokens"], int), "input_tokens should be integer"
|
||||
assert isinstance(res.body["usage"]["output_tokens"], int), "output_tokens should be integer"
|
||||
assert res.body["usage"]["output_tokens"] > 0, "Should have generated some tokens"
|
||||
# Anthropic API should NOT include timings
|
||||
assert "timings" not in res.body, "Anthropic API should not include timings field"
|
||||
|
||||
|
||||
def test_anthropic_messages_with_system():
|
||||
"""Test messages with system prompt"""
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 50,
|
||||
"system": "You are a helpful assistant.",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.body["type"] == "message"
|
||||
assert len(res.body["content"]) > 0
|
||||
|
||||
|
||||
def test_anthropic_messages_multipart_content():
|
||||
"""Test messages with multipart content blocks"""
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 50,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What is"},
|
||||
{"type": "text", "text": " the answer?"}
|
||||
]
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.body["type"] == "message"
|
||||
|
||||
|
||||
def test_anthropic_messages_conversation():
|
||||
"""Test multi-turn conversation"""
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 50,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "user", "content": "How are you?"}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.body["type"] == "message"
|
||||
|
||||
|
||||
# Streaming tests
|
||||
|
||||
def test_anthropic_messages_streaming():
|
||||
"""Test streaming messages"""
|
||||
server.start()
|
||||
|
||||
res = server.make_stream_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 30,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Say hello"}
|
||||
],
|
||||
"stream": True
|
||||
})
|
||||
|
||||
events = []
|
||||
for data in res:
|
||||
# Each event should have type and other fields
|
||||
assert "type" in data, f"Missing 'type' in event: {data}"
|
||||
events.append(data)
|
||||
|
||||
# Verify event sequence
|
||||
event_types = [e["type"] for e in events]
|
||||
assert "message_start" in event_types, "Missing message_start event"
|
||||
assert "content_block_start" in event_types, "Missing content_block_start event"
|
||||
assert "content_block_delta" in event_types, "Missing content_block_delta event"
|
||||
assert "content_block_stop" in event_types, "Missing content_block_stop event"
|
||||
assert "message_delta" in event_types, "Missing message_delta event"
|
||||
assert "message_stop" in event_types, "Missing message_stop event"
|
||||
|
||||
# Check message_start structure
|
||||
message_start = next(e for e in events if e["type"] == "message_start")
|
||||
assert "message" in message_start, "message_start missing 'message' field"
|
||||
assert message_start["message"]["type"] == "message"
|
||||
assert message_start["message"]["role"] == "assistant"
|
||||
assert message_start["message"]["content"] == []
|
||||
assert "usage" in message_start["message"]
|
||||
assert message_start["message"]["usage"]["input_tokens"] > 0
|
||||
|
||||
# Check content_block_start
|
||||
block_start = next(e for e in events if e["type"] == "content_block_start")
|
||||
assert "index" in block_start, "content_block_start missing 'index'"
|
||||
assert block_start["index"] == 0, "First content block should be at index 0"
|
||||
assert "content_block" in block_start
|
||||
assert block_start["content_block"]["type"] == "text"
|
||||
|
||||
# Check content_block_delta
|
||||
deltas = [e for e in events if e["type"] == "content_block_delta"]
|
||||
assert len(deltas) > 0, "Should have at least one content_block_delta"
|
||||
for delta in deltas:
|
||||
assert "index" in delta
|
||||
assert "delta" in delta
|
||||
assert delta["delta"]["type"] == "text_delta"
|
||||
assert "text" in delta["delta"]
|
||||
|
||||
# Check content_block_stop
|
||||
block_stop = next(e for e in events if e["type"] == "content_block_stop")
|
||||
assert "index" in block_stop
|
||||
assert block_stop["index"] == 0
|
||||
|
||||
# Check message_delta
|
||||
message_delta = next(e for e in events if e["type"] == "message_delta")
|
||||
assert "delta" in message_delta
|
||||
assert "stop_reason" in message_delta["delta"]
|
||||
assert message_delta["delta"]["stop_reason"] in ["end_turn", "max_tokens"]
|
||||
assert "usage" in message_delta
|
||||
assert message_delta["usage"]["output_tokens"] > 0
|
||||
|
||||
# Check message_stop
|
||||
message_stop = next(e for e in events if e["type"] == "message_stop")
|
||||
# message_stop should NOT have timings for Anthropic API
|
||||
assert "timings" not in message_stop, "Anthropic streaming should not include timings"
|
||||
|
||||
|
||||
# Token counting tests
|
||||
|
||||
def test_anthropic_count_tokens():
|
||||
"""Test token counting endpoint"""
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages/count_tokens", data={
|
||||
"model": "test",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello world"}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert "input_tokens" in res.body
|
||||
assert isinstance(res.body["input_tokens"], int)
|
||||
assert res.body["input_tokens"] > 0
|
||||
# Should only have input_tokens, no other fields
|
||||
assert "output_tokens" not in res.body
|
||||
|
||||
|
||||
def test_anthropic_count_tokens_with_system():
|
||||
"""Test token counting with system prompt"""
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages/count_tokens", data={
|
||||
"model": "test",
|
||||
"system": "You are a helpful assistant.",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.body["input_tokens"] > 0
|
||||
|
||||
|
||||
def test_anthropic_count_tokens_no_max_tokens():
|
||||
"""Test that count_tokens doesn't require max_tokens"""
|
||||
server.start()
|
||||
|
||||
# max_tokens is NOT required for count_tokens
|
||||
res = server.make_request("POST", "/v1/messages/count_tokens", data={
|
||||
"model": "test",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert "input_tokens" in res.body
|
||||
|
||||
|
||||
# Tool use tests
|
||||
|
||||
def test_anthropic_tool_use_basic():
|
||||
"""Test basic tool use"""
|
||||
server.jinja = True
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 200,
|
||||
"tools": [{
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather in a location",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "City name"
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}],
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather in Paris?"}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.body["type"] == "message"
|
||||
assert len(res.body["content"]) > 0
|
||||
|
||||
# Check if model used the tool (it might not always, depending on the model)
|
||||
content_types = [block.get("type") for block in res.body["content"]]
|
||||
|
||||
if "tool_use" in content_types:
|
||||
# Model used the tool
|
||||
assert res.body["stop_reason"] == "tool_use"
|
||||
|
||||
# Find the tool_use block
|
||||
tool_block = next(b for b in res.body["content"] if b.get("type") == "tool_use")
|
||||
assert "id" in tool_block
|
||||
assert "name" in tool_block
|
||||
assert tool_block["name"] == "get_weather"
|
||||
assert "input" in tool_block
|
||||
assert isinstance(tool_block["input"], dict)
|
||||
|
||||
|
||||
def test_anthropic_tool_result():
|
||||
"""Test sending tool results back
|
||||
|
||||
This test verifies that tool_result blocks are properly converted to
|
||||
role="tool" messages internally. Without proper conversion, this would
|
||||
fail with a 500 error: "unsupported content[].type" because tool_result
|
||||
blocks would remain in the user message content array.
|
||||
"""
|
||||
server.jinja = True
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 100,
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "test123",
|
||||
"name": "get_weather",
|
||||
"input": {"location": "Paris"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "test123",
|
||||
"content": "The weather is sunny, 25°C"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
# This would be 500 with the old bug where tool_result blocks weren't converted
|
||||
assert res.status_code == 200
|
||||
assert res.body["type"] == "message"
|
||||
# Model should respond to the tool result
|
||||
assert len(res.body["content"]) > 0
|
||||
assert res.body["content"][0]["type"] == "text"
|
||||
|
||||
|
||||
def test_anthropic_tool_result_with_text():
|
||||
"""Test tool result mixed with text content
|
||||
|
||||
This tests the edge case where a user message contains both text and
|
||||
tool_result blocks. The server must properly split these into separate
|
||||
messages: a user message with text, followed by tool messages.
|
||||
Without proper handling, this would fail with 500: "unsupported content[].type"
|
||||
"""
|
||||
server.jinja = True
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 100,
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "tool_1",
|
||||
"name": "get_weather",
|
||||
"input": {"location": "Paris"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Here are the results:"},
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "tool_1",
|
||||
"content": "Sunny, 25°C"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.body["type"] == "message"
|
||||
assert len(res.body["content"]) > 0
|
||||
|
||||
|
||||
def test_anthropic_tool_result_error():
|
||||
"""Test tool result with error flag"""
|
||||
server.jinja = True
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 100,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Get the weather"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "test123",
|
||||
"name": "get_weather",
|
||||
"input": {"location": "InvalidCity"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "test123",
|
||||
"is_error": True,
|
||||
"content": "City not found"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.body["type"] == "message"
|
||||
|
||||
|
||||
def test_anthropic_tool_streaming():
|
||||
"""Test streaming with tool use"""
|
||||
server.jinja = True
|
||||
server.start()
|
||||
|
||||
res = server.make_stream_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 200,
|
||||
"stream": True,
|
||||
"tools": [{
|
||||
"name": "calculator",
|
||||
"description": "Calculate math",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"expression": {"type": "string"}
|
||||
},
|
||||
"required": ["expression"]
|
||||
}
|
||||
}],
|
||||
"messages": [
|
||||
{"role": "user", "content": "Calculate 2+2"}
|
||||
]
|
||||
})
|
||||
|
||||
events = []
|
||||
for data in res:
|
||||
events.append(data)
|
||||
|
||||
event_types = [e["type"] for e in events]
|
||||
|
||||
# Should have basic events
|
||||
assert "message_start" in event_types
|
||||
assert "message_stop" in event_types
|
||||
|
||||
# If tool was used, check for proper tool streaming
|
||||
if any(e.get("type") == "content_block_start" and
|
||||
e.get("content_block", {}).get("type") == "tool_use"
|
||||
for e in events):
|
||||
# Find tool use block start
|
||||
tool_starts = [e for e in events if
|
||||
e.get("type") == "content_block_start" and
|
||||
e.get("content_block", {}).get("type") == "tool_use"]
|
||||
|
||||
assert len(tool_starts) > 0, "Should have tool_use content_block_start"
|
||||
|
||||
# Check index is correct (should be 0 if no text, 1 if there's text)
|
||||
tool_start = tool_starts[0]
|
||||
assert "index" in tool_start
|
||||
assert tool_start["content_block"]["type"] == "tool_use"
|
||||
assert "name" in tool_start["content_block"]
|
||||
|
||||
|
||||
# Vision/multimodal tests
|
||||
|
||||
def test_anthropic_vision_format_accepted():
|
||||
"""Test that Anthropic vision format is accepted (format validation only)"""
|
||||
server.start()
|
||||
|
||||
# Small 1x1 red PNG image in base64
|
||||
red_pixel_png = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg=="
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 10,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": red_pixel_png
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What is this?"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
# Server accepts the format but tinyllama doesn't support images
|
||||
# So it should return 500 with clear error message about missing mmproj
|
||||
assert res.status_code == 500
|
||||
assert "image input is not supported" in res.body.get("error", {}).get("message", "").lower()
|
||||
|
||||
|
||||
def test_anthropic_vision_base64_with_multimodal_model(vision_server):
|
||||
"""Test vision with base64 image using Anthropic format with multimodal model"""
|
||||
global server
|
||||
server = vision_server
|
||||
server.start()
|
||||
|
||||
# Get test image in base64 format
|
||||
image_base64 = get_test_image_base64()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 10,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": image_base64
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What is this:\n"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200, f"Expected 200, got {res.status_code}: {res.body}"
|
||||
assert res.body["type"] == "message"
|
||||
assert len(res.body["content"]) > 0
|
||||
assert res.body["content"][0]["type"] == "text"
|
||||
# The model should generate some response about the image
|
||||
assert len(res.body["content"][0]["text"]) > 0
|
||||
|
||||
|
||||
# Parameter tests
|
||||
|
||||
def test_anthropic_stop_sequences():
|
||||
"""Test stop_sequences parameter"""
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 100,
|
||||
"stop_sequences": ["\n", "END"],
|
||||
"messages": [
|
||||
{"role": "user", "content": "Count to 10"}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.body["type"] == "message"
|
||||
|
||||
|
||||
def test_anthropic_temperature():
|
||||
"""Test temperature parameter"""
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 50,
|
||||
"temperature": 0.5,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.body["type"] == "message"
|
||||
|
||||
|
||||
def test_anthropic_top_p():
|
||||
"""Test top_p parameter"""
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 50,
|
||||
"top_p": 0.9,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.body["type"] == "message"
|
||||
|
||||
|
||||
def test_anthropic_top_k():
|
||||
"""Test top_k parameter (llama.cpp specific)"""
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 50,
|
||||
"top_k": 40,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.body["type"] == "message"
|
||||
|
||||
|
||||
# Error handling tests
|
||||
|
||||
def test_anthropic_missing_messages():
|
||||
"""Test error when messages are missing"""
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 50
|
||||
# missing "messages" field
|
||||
})
|
||||
|
||||
# Should return an error (400 or 500)
|
||||
assert res.status_code >= 400
|
||||
|
||||
|
||||
def test_anthropic_empty_messages():
|
||||
"""Test permissive handling of empty messages array"""
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 50,
|
||||
"messages": []
|
||||
})
|
||||
|
||||
# Server is permissive and accepts empty messages (provides defaults)
|
||||
# This matches the permissive validation design choice
|
||||
assert res.status_code == 200
|
||||
assert res.body["type"] == "message"
|
||||
|
||||
|
||||
# Content block index tests
|
||||
|
||||
def test_anthropic_streaming_content_block_indices():
|
||||
"""Test that content block indices are correct in streaming"""
|
||||
server.jinja = True
|
||||
server.start()
|
||||
|
||||
# Request that might produce both text and tool use
|
||||
res = server.make_stream_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 200,
|
||||
"stream": True,
|
||||
"tools": [{
|
||||
"name": "test_tool",
|
||||
"description": "A test tool",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"param": {"type": "string"}
|
||||
},
|
||||
"required": ["param"]
|
||||
}
|
||||
}],
|
||||
"messages": [
|
||||
{"role": "user", "content": "Use the test tool"}
|
||||
]
|
||||
})
|
||||
|
||||
events = []
|
||||
for data in res:
|
||||
events.append(data)
|
||||
|
||||
# Check content_block_start events have sequential indices
|
||||
block_starts = [e for e in events if e.get("type") == "content_block_start"]
|
||||
if len(block_starts) > 1:
|
||||
# If there are multiple blocks, indices should be sequential
|
||||
indices = [e["index"] for e in block_starts]
|
||||
expected_indices = list(range(len(block_starts)))
|
||||
assert indices == expected_indices, f"Expected indices {expected_indices}, got {indices}"
|
||||
|
||||
# Check content_block_stop events match the starts
|
||||
block_stops = [e for e in events if e.get("type") == "content_block_stop"]
|
||||
start_indices = set(e["index"] for e in block_starts)
|
||||
stop_indices = set(e["index"] for e in block_stops)
|
||||
assert start_indices == stop_indices, "content_block_stop indices should match content_block_start indices"
|
||||
|
||||
|
||||
# Extended features tests
|
||||
|
||||
def test_anthropic_thinking():
|
||||
"""Test extended thinking parameter"""
|
||||
server.jinja = True
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 100,
|
||||
"thinking": {
|
||||
"type": "enabled",
|
||||
"budget_tokens": 50
|
||||
},
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is 2+2?"}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.body["type"] == "message"
|
||||
|
||||
|
||||
def test_anthropic_metadata():
|
||||
"""Test metadata parameter"""
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 50,
|
||||
"metadata": {
|
||||
"user_id": "test_user_123"
|
||||
},
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.body["type"] == "message"
|
||||
|
||||
|
||||
# Compatibility tests
|
||||
|
||||
def test_anthropic_vs_openai_different_response_format():
|
||||
"""Verify Anthropic format is different from OpenAI format"""
|
||||
server.start()
|
||||
|
||||
# Make OpenAI request
|
||||
openai_res = server.make_request("POST", "/v1/chat/completions", data={
|
||||
"model": "test",
|
||||
"max_tokens": 50,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
})
|
||||
|
||||
# Make Anthropic request
|
||||
anthropic_res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 50,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
})
|
||||
|
||||
assert openai_res.status_code == 200
|
||||
assert anthropic_res.status_code == 200
|
||||
|
||||
# OpenAI has "object", Anthropic has "type"
|
||||
assert "object" in openai_res.body
|
||||
assert "type" in anthropic_res.body
|
||||
assert openai_res.body["object"] == "chat.completion"
|
||||
assert anthropic_res.body["type"] == "message"
|
||||
|
||||
# OpenAI has "choices", Anthropic has "content"
|
||||
assert "choices" in openai_res.body
|
||||
assert "content" in anthropic_res.body
|
||||
|
||||
# Different usage field names
|
||||
assert "prompt_tokens" in openai_res.body["usage"]
|
||||
assert "input_tokens" in anthropic_res.body["usage"]
|
||||
assert "completion_tokens" in openai_res.body["usage"]
|
||||
assert "output_tokens" in anthropic_res.body["usage"]
|
||||
|
|
@ -49,6 +49,19 @@ def test_correct_api_key():
|
|||
assert "content" in res.body
|
||||
|
||||
|
||||
def test_correct_api_key_anthropic_header():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completions", data={
|
||||
"prompt": "I believe the meaning of life is",
|
||||
}, headers={
|
||||
"X-Api-Key": TEST_API_KEY,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert "error" not in res.body
|
||||
assert "content" in res.body
|
||||
|
||||
|
||||
def test_openai_library_correct_api_key():
|
||||
global server
|
||||
server.start()
|
||||
|
|
|
|||
Loading…
Reference in New Issue