Merge branch 'master' of https://github.com/ggerganov/llama.cpp into cli_output

This commit is contained in:
David Baker 2026-03-11 13:58:53 +00:00
commit f60b3cc4f0
No known key found for this signature in database
GPG Key ID: 89298D31E5B7B548
190 changed files with 10901 additions and 22106 deletions

View File

@ -469,6 +469,7 @@ jobs:
cd build
export GGML_VK_VISIBLE_DEVICES=0
export GGML_VK_DISABLE_F16=1
export GGML_VK_DISABLE_COOPMAT=1
# This is using llvmpipe and runs slower than other backends
ctest -L main --verbose --timeout 4800

View File

@ -81,6 +81,8 @@ add_library(${TARGET} STATIC
preset.cpp
preset.h
regex-partial.cpp
reasoning-budget.cpp
reasoning-budget.h
regex-partial.h
sampling.cpp
sampling.h

View File

@ -2427,11 +2427,11 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
);
}
if (split_arg.size() == 1) {
std::fill(params.fit_params_target.begin(), params.fit_params_target.end(), std::stoul(split_arg[0]) * 1024*1024);
std::fill(params.fit_params_target.begin(), params.fit_params_target.end(), std::stoull(split_arg[0]) * 1024*1024);
return;
}
for (size_t i = 0; i < split_arg.size(); i++) {
params.fit_params_target[i] = std::stoul(split_arg[i]) * 1024*1024;
params.fit_params_target[i] = std::stoull(split_arg[i]) * 1024*1024;
}
}
).set_env("LLAMA_ARG_FIT_TARGET"));
@ -2913,6 +2913,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) {
auto parsed = json::parse(value);
for (const auto & item : parsed.items()) {
if (item.key() == "enable_thinking") {
LOG_WRN("Setting 'enable_thinking' via --chat-template-kwargs is deprecated. "
"Use --reasoning on / --reasoning off instead.\n");
}
params.default_template_kwargs[item.key()] = item.value().dump();
}
}
@ -3048,14 +3052,39 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.reasoning_format = common_reasoning_format_from_name(value);
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK"));
add_opt(common_arg(
{"-rea", "--reasoning"}, "[on|off|auto]",
"Use reasoning/thinking in the chat ('on', 'off', or 'auto', default: 'auto' (detect from template))",
[](common_params & params, const std::string & value) {
if (is_truthy(value)) {
params.enable_reasoning = 1;
params.default_template_kwargs["enable_thinking"] = "true";
} else if (is_falsey(value)) {
params.enable_reasoning = 0;
params.default_template_kwargs["enable_thinking"] = "false";
} else if (is_autoy(value)) {
params.enable_reasoning = -1;
} else {
throw std::invalid_argument(
string_format("error: unknown value for --reasoning: '%s'\n", value.c_str()));
}
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_REASONING"));
add_opt(common_arg(
{"--reasoning-budget"}, "N",
"controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)",
"token budget for thinking: -1 for unrestricted, 0 for immediate end, N>0 for token budget (default: -1)",
[](common_params & params, int value) {
if (value != 0 && value != -1) { throw std::invalid_argument("invalid value"); }
if (value < -1) { throw std::invalid_argument("invalid value"); }
params.reasoning_budget = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET"));
add_opt(common_arg(
{"--reasoning-budget-message"}, "MESSAGE",
"message injected before the end-of-thinking tag when reasoning budget is exhausted (default: none)",
[](common_params & params, const std::string & value) {
params.reasoning_budget_message = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_MESSAGE"));
add_opt(common_arg(
{"--chat-template"}, "JINJA_TEMPLATE",
string_format(

View File

@ -90,7 +90,7 @@ common_peg_arena autoparser::build_parser(const templates_params & inputs) const
// pre-register a json-string rule that accepts both quote styles. This must happen
// before any call to p.json() so that all JSON parsing inherits the flexible rule.
if (tools.format.uses_python_dicts) {
p.rule("json-string", [&]() { return p.choice({ p.double_quoted_string(), p.single_quoted_string() }); });
p.rule("json-string", p.quoted_string());
}
parser_build_context ctx(p, inputs);
@ -135,7 +135,9 @@ common_peg_parser analyze_reasoning::build_parser(parser_build_context & ctx) co
if (thinking_forced_open || thinking_forced_closed) {
// Thinking is forced open OR forced closed with enable_thinking=true
// In both cases, expect only the closing tag (opening was in template)
return p.reasoning(p.until(end)) + end;
// However, since we might have incorrectly detected the open/close pattern,
// we admit an optional starting marker
return p.optional(p.literal(start)) + p.reasoning(p.until(end)) + end;
}
if (mode == reasoning_mode::TAG_BASED || mode == reasoning_mode::TOOLS_ONLY) {
// Standard tag-based reasoning OR tools-only mode (reasoning appears with tools)

View File

@ -6,7 +6,7 @@
#include <nlohmann/json.hpp>
using json = nlohmann::ordered_json;
using ordered_json = nlohmann::ordered_json;
static std::string_view trim_trailing_space(std::string_view sv, int max = -1) {
int count = 0;
@ -68,7 +68,7 @@ static int json_brace_depth(const std::string & s) {
// JSON-escape a string and return the inner content (without surrounding quotes).
static std::string escape_json_string_inner(const std::string & s) {
std::string escaped = json(s).dump();
std::string escaped = ordered_json(s).dump();
if (escaped.size() >= 2 && escaped.front() == '"' && escaped.back() == '"') {
return escaped.substr(1, escaped.size() - 2);
}
@ -309,7 +309,7 @@ void common_chat_peg_mapper::map(const common_peg_ast_node & node) {
if (arg_count > 0) {
arg_entry = ",";
}
arg_entry += json(trim(node.text)).dump() + ":";
arg_entry += ordered_json(trim(node.text)).dump() + ":";
++arg_count;
auto & target = args_target();
@ -343,7 +343,7 @@ void common_chat_peg_mapper::map(const common_peg_ast_node & node) {
// Try to parse as JSON value (number, bool, null, object, array)
try {
json parsed = json::parse(value_content);
ordered_json parsed = ordered_json::parse(value_content);
if (parsed.is_string()) {
// Don't add closing quote yet (added by arg_close) for monotonic streaming
std::string escaped = parsed.dump();
@ -408,7 +408,7 @@ void common_chat_peg_mapper::map(const common_peg_ast_node & node) {
common_peg_parser common_chat_peg_builder::standard_constructed_tools(
const std::map<std::string, std::string> & markers,
const nlohmann::json & tools,
const ordered_json & tools,
bool parallel_tool_calls,
bool force_tool_calls) {
if (!tools.is_array() || tools.empty()) {
@ -439,7 +439,7 @@ common_peg_parser common_chat_peg_builder::standard_constructed_tools(
}
const auto & function = tool_def.at("function");
std::string name = function.at("name");
nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object();
ordered_json params = function.contains("parameters") ? function.at("parameters") : ordered_json::object();
// Build argument parsers
auto args = eps();
@ -479,8 +479,8 @@ common_peg_parser common_chat_peg_builder::standard_constructed_tools(
// Python-style tool calls: name(arg1="value1", arg2=123)
// Used only by LFM2 for now, so we don't merge it into autoparser
common_peg_parser common_chat_peg_builder::python_style_tool_calls(
const nlohmann::json & tools,
bool parallel_tool_calls) {
const ordered_json & tools,
bool parallel_tool_calls) {
if (!tools.is_array() || tools.empty()) {
return eps();
}
@ -493,7 +493,7 @@ common_peg_parser common_chat_peg_builder::python_style_tool_calls(
}
const auto & function = tool_def.at("function");
std::string name = function.at("name");
nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object();
ordered_json params = function.contains("parameters") ? function.at("parameters") : ordered_json::object();
auto args = eps();
if (params.contains("properties") && !params["properties"].empty()) {
@ -507,8 +507,8 @@ common_peg_parser common_chat_peg_builder::python_style_tool_calls(
common_peg_parser arg_value_parser = eps();
auto string_value_parser = choice({
literal("\"") + tool_arg_string_value(json_string_content()) + literal("\""),
literal("'") + tool_arg_string_value(json_string_content()) + literal("'")
literal("\"") + tool_arg_string_value(string_content('"')) + literal("\""),
literal("'") + tool_arg_string_value(string_content('\'')) + literal("'")
});
if (is_string_type) {
@ -555,11 +555,11 @@ static std::pair<std::string, std::string> parse_key_spec(const std::string & ke
// Mode 1: function_is_key — parse {"function_name": {...}}
common_peg_parser common_chat_peg_builder::build_json_tools_function_is_key(
const nlohmann::json & tools,
const std::string & args_key,
const std::string & effective_args_key,
const std::string & call_id_key,
const std::string & gen_call_id_key) {
const ordered_json & tools,
const std::string & args_key,
const std::string & effective_args_key,
const std::string & call_id_key,
const std::string & gen_call_id_key) {
auto tool_choices = choice();
@ -569,7 +569,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_function_is_key(
}
const auto & function = tool_def.at("function");
std::string name = function.at("name");
nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object();
ordered_json params = function.contains("parameters") ? function.at("parameters") : ordered_json::object();
// Build inner object fields
std::vector<common_peg_parser> inner_fields;
@ -577,7 +577,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_function_is_key(
if (!call_id_key.empty()) {
auto id_parser = atomic(
literal("\"" + call_id_key + "\"") + space() + literal(":") + space() +
literal("\"") + tool_id(json_string_content()) + literal("\"")
literal("\"") + tool_id(string_content('"')) + literal("\"")
);
inner_fields.push_back(optional(id_parser + space() + optional(literal(",") + space())));
}
@ -586,7 +586,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_function_is_key(
auto gen_id_parser = atomic(
literal("\"" + gen_call_id_key + "\"") + space() + literal(":") + space() +
choice({
literal("\"") + tool_id(json_string_content()) + literal("\""),
literal("\"") + tool_id(string_content('"')) + literal("\""),
tool_id(json_number())
})
);
@ -634,11 +634,11 @@ common_peg_parser common_chat_peg_builder::build_json_tools_function_is_key(
// Mode 2: Nested keys (dot notation like "function.name")
common_peg_parser common_chat_peg_builder::build_json_tools_nested_keys(
const nlohmann::json & tools,
const std::string & effective_name_key,
const std::string & effective_args_key,
const std::string & call_id_key,
const std::string & gen_call_id_key) {
const ordered_json & tools,
const std::string & effective_name_key,
const std::string & effective_args_key,
const std::string & call_id_key,
const std::string & gen_call_id_key) {
auto tool_choices = choice();
@ -655,7 +655,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_nested_keys(
}
const auto & function = tool_def.at("function");
std::string name = function.at("name");
nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object();
ordered_json params = function.contains("parameters") ? function.at("parameters") : ordered_json::object();
auto nested_name = literal("\"" + nested_name_field + "\"") + space() + literal(":") + space() +
literal("\"") + tool_name(literal(name)) + literal("\"");
@ -675,7 +675,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_nested_keys(
if (id_spec.first.empty()) {
auto id_parser = atomic(
literal("\"" + call_id_key + "\"") + space() + literal(":") + space() +
literal("\"") + tool_id(json_string_content()) + literal("\"")
literal("\"") + tool_id(string_content('"')) + literal("\"")
);
tool_parser_body = tool_parser_body + optional(id_parser + space() + literal(",") + space());
}
@ -687,7 +687,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_nested_keys(
auto gen_id_parser = atomic(
literal("\"" + gen_call_id_key + "\"") + space() + literal(":") + space() +
choice({
literal("\"") + tool_id(json_string_content()) + literal("\""),
literal("\"") + tool_id(string_content('"')) + literal("\""),
tool_id(json_number())
})
);
@ -706,7 +706,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_nested_keys(
// Mode 3: Flat keys with optional ID fields and parameter ordering
common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys(
const nlohmann::json & tools,
const ordered_json & tools,
const std::string & effective_name_key,
const std::string & effective_args_key,
const std::string & call_id_key,
@ -723,7 +723,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys(
}
const auto & function = tool_def.at("function");
std::string name = function.at("name");
nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object();
ordered_json params = function.contains("parameters") ? function.at("parameters") : ordered_json::object();
auto tool_name_ = name_key_parser + space() + literal(":") + space() +
literal("\"") + tool_name(literal(name)) + literal("\"");
@ -736,7 +736,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys(
id_parser = atomic(
literal("\"" + call_id_key + "\"") + space() + literal(":") + space() +
choice({
literal("\"") + tool_id(json_string_content()) + literal("\""),
literal("\"") + tool_id(string_content('"')) + literal("\""),
tool_id(json_number())
})
);
@ -747,7 +747,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys(
gen_id_parser = atomic(
literal("\"" + gen_call_id_key + "\"") + space() + literal(":") + space() +
choice({
literal("\"") + tool_id(json_string_content()) + literal("\""),
literal("\"") + tool_id(string_content('"')) + literal("\""),
tool_id(json_number())
})
);
@ -791,7 +791,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys(
common_peg_parser common_chat_peg_builder::standard_json_tools(
const std::string & section_start,
const std::string & section_end,
const nlohmann::json & tools,
const ordered_json & tools,
bool parallel_tool_calls,
bool force_tool_calls,
const std::string & name_key,

View File

@ -94,7 +94,7 @@ class common_chat_peg_builder : public common_peg_parser_builder {
// parameters_order: order in which JSON fields should be parsed
common_peg_parser standard_json_tools(const std::string & section_start,
const std::string & section_end,
const nlohmann::json & tools,
const nlohmann::ordered_json & tools,
bool parallel_tool_calls,
bool force_tool_calls,
const std::string & name_key = "",
@ -108,30 +108,30 @@ class common_chat_peg_builder : public common_peg_parser_builder {
// Legacy-compatible helper for building XML/tagged style tool calls
// Used by tests and manual parsers
common_peg_parser standard_constructed_tools(const std::map<std::string, std::string> & markers,
const nlohmann::json & tools,
const nlohmann::ordered_json & tools,
bool parallel_tool_calls,
bool force_tool_calls);
// Helper for Python-style function call format: name(arg1="value1", arg2=123)
// Used by LFM2 and similar templates
common_peg_parser python_style_tool_calls(const nlohmann::json & tools,
bool parallel_tool_calls);
common_peg_parser python_style_tool_calls(const nlohmann::ordered_json & tools,
bool parallel_tool_calls);
private:
// Implementation helpers for standard_json_tools — one per JSON tool call layout mode
common_peg_parser build_json_tools_function_is_key(const nlohmann::json & tools,
const std::string & args_key,
const std::string & effective_args_key,
const std::string & call_id_key,
const std::string & gen_call_id_key);
common_peg_parser build_json_tools_function_is_key(const nlohmann::ordered_json & tools,
const std::string & args_key,
const std::string & effective_args_key,
const std::string & call_id_key,
const std::string & gen_call_id_key);
common_peg_parser build_json_tools_nested_keys(const nlohmann::json & tools,
const std::string & effective_name_key,
const std::string & effective_args_key,
const std::string & call_id_key,
const std::string & gen_call_id_key);
common_peg_parser build_json_tools_nested_keys(const nlohmann::ordered_json & tools,
const std::string & effective_name_key,
const std::string & effective_args_key,
const std::string & call_id_key,
const std::string & gen_call_id_key);
common_peg_parser build_json_tools_flat_keys(const nlohmann::json & tools,
common_peg_parser build_json_tools_flat_keys(const nlohmann::ordered_json & tools,
const std::string & effective_name_key,
const std::string & effective_args_key,
const std::string & call_id_key,

View File

@ -857,7 +857,9 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
auto include_grammar = true;
data.supports_thinking = true;
data.supports_thinking = true;
data.thinking_start_tag = "[THINK]";
data.thinking_end_tag = "[/THINK]";
data.prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override = */ adjusted_messages);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.preserved_tokens = {
@ -1165,9 +1167,11 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
const autoparser::templates_params & inputs) {
common_chat_params data;
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true;
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true;
data.thinking_start_tag = "<think>";
data.thinking_end_tag = "</think>";
data.preserved_tokens = {
"<|tool_calls_section_begin|>",
"<|tool_calls_section_end|>",
@ -1352,6 +1356,17 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
namespace workaround {
// Workaround for templates that only understand the classic role names:
// rewrite every message whose "role" field is "developer" to "system",
// mutating the messages array in place. Messages without a "role" key
// are left untouched.
static void map_developer_role_to_system(json & messages) {
    for (auto & msg : messages) {
        if (msg.contains("role") && msg["role"] == "developer") {
            msg["role"] = "system";
        }
    }
}
// if first message is system and template does not support it, merge it with next message
static void system_message_not_supported(json & messages) {
if (!messages.empty() && messages.front().at("role") == "system") {
@ -1429,6 +1444,10 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
params.add_bos = tmpls->add_bos;
params.add_eos = tmpls->add_eos;
if (src.find("<|channel|>") == std::string::npos) {
// map developer to system for all models except for GPT-OSS
workaround::map_developer_role_to_system(params.messages);
}
workaround::func_args_not_string(params.messages);
if (!tmpl.original_caps().supports_system_role) {
@ -1512,6 +1531,16 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
autoparser.analyze_template(tmpl);
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE;
if (auto_params.supports_thinking) {
auto_params.thinking_start_tag = autoparser.reasoning.start;
auto_params.thinking_end_tag = autoparser.reasoning.end;
// FORCED_OPEN and FORCED_CLOSED both put <think> in the generation prompt
// (FORCED_CLOSED forces empty <think></think> when thinking is disabled,
// but forces <think> open when thinking is enabled)
auto_params.thinking_forced_open =
autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_OPEN ||
autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_CLOSED;
}
return auto_params;
} catch (const std::exception & e) {
throw std::invalid_argument(std::string("Unable to generate parser for this template. Automatic parser generation failed: ") + e.what());
@ -1605,8 +1634,8 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
build_chat_peg_parser([](common_chat_peg_builder & p) { return p.content(p.rest()) + p.end(); }) :
src_parser;
if (src_parser.empty()) {
LOG_WRN("No parser definition detected, assuming pure content parser.");
if (src_parser.empty()) {
LOG_DBG("No parser definition detected, assuming pure content parser.");
}
LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), input.c_str());

View File

@ -213,6 +213,8 @@ struct common_chat_params {
bool grammar_lazy = false;
bool thinking_forced_open = false;
bool supports_thinking = false;
std::string thinking_start_tag; // e.g., "<think>"
std::string thinking_end_tag; // e.g., "</think>"
std::vector<common_grammar_trigger> grammar_triggers;
std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops;

View File

@ -235,6 +235,14 @@ struct common_params_sampling {
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
// reasoning budget sampler parameters
// these are populated by the server/CLI based on chat template params
int32_t reasoning_budget_tokens = -1; // -1 = disabled, >= 0 = token budget
bool reasoning_budget_activate_immediately = false;
std::vector<llama_token> reasoning_budget_start; // start tag token sequence
std::vector<llama_token> reasoning_budget_end; // end tag token sequence
std::vector<llama_token> reasoning_budget_forced; // forced sequence (message + end tag)
bool backend_sampling = false;
bool has_logit_bias() const {
@ -536,7 +544,9 @@ struct common_params {
bool use_jinja = true; // NOLINT
bool enable_chat_template = true;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
int enable_reasoning = -1; // -1 = auto, 0 = disable, 1 = enable
int reasoning_budget = -1;
std::string reasoning_budget_message; // message injected before end tag when budget exhausted
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
int sleep_idle_seconds = -1; // if >0, server will sleep after this many seconds of idle time

View File

@ -7,6 +7,7 @@ struct common_http_url {
std::string user;
std::string password;
std::string host;
int port;
std::string path;
};
@ -47,6 +48,20 @@ static common_http_url common_http_parse_url(const std::string & url) {
parts.host = rest;
parts.path = "/";
}
auto colon_pos = parts.host.find(':');
if (colon_pos != std::string::npos) {
parts.port = std::stoi(parts.host.substr(colon_pos + 1));
parts.host = parts.host.substr(0, colon_pos);
} else if (parts.scheme == "http") {
parts.port = 80;
} else if (parts.scheme == "https") {
parts.port = 443;
} else {
throw std::runtime_error("unsupported URL scheme: " + parts.scheme);
}
return parts;
}
@ -68,7 +83,7 @@ static std::pair<httplib::Client, common_http_url> common_http_client(const std:
}
#endif
httplib::Client cli(parts.scheme + "://" + parts.host);
httplib::Client cli(parts.scheme + "://" + parts.host + ":" + std::to_string(parts.port));
if (!parts.user.empty()) {
cli.set_basic_auth(parts.user, parts.password);

View File

@ -790,7 +790,7 @@ public:
} else if (target.is_array()) {
size_t sel_index;
try {
sel_index = std::stoul(sel);
sel_index = std::stoull(sel);
} catch (const std::invalid_argument & e) {
sel_index = target.size();
}

View File

@ -658,7 +658,7 @@ struct parser_executor {
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos);
}
static common_peg_parse_result handle_escape_sequence(common_peg_parse_context & ctx, size_t start, size_t & pos) {
static common_peg_parse_result handle_escape_sequence(common_peg_parse_context & ctx, size_t start, size_t & pos, const char delimiter) {
++pos; // consume '\'
if (pos >= ctx.input.size()) {
if (!ctx.is_lenient()) {
@ -667,23 +667,14 @@ struct parser_executor {
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start, pos);
}
switch (ctx.input[pos]) {
case '"':
case '\'':
case '\\':
case '/':
case 'b':
case 'f':
case 'n':
case 'r':
case 't':
++pos;
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start, pos);
case 'u':
return handle_unicode_escape(ctx, start, pos);
default:
// Invalid escape sequence
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start);
char c = ctx.input[pos];
if (c == delimiter || c == '\\' || c == '/' || c == 'b' || c == 'f' || c == 'n' || c == 'r' || c == 't') {
++pos;
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start, pos);
} else if (c == 'u') {
return handle_unicode_escape(ctx, start, pos);
} else {
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start);
}
}
@ -704,62 +695,20 @@ struct parser_executor {
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start, pos);
}
common_peg_parse_result operator()(const common_peg_json_string_parser & /* p */) {
common_peg_parse_result operator()(const common_peg_string_parser & p) {
auto pos = start_pos;
// Parse string content (without quotes)
while (pos < ctx.input.size()) {
char c = ctx.input[pos];
if (c == '"') {
// Found closing quote - success (don't consume it)
if (c == p.delimiter) {
// Found closing delimiter - success (don't consume it)
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos);
}
if (c == '\\') {
auto result = handle_escape_sequence(ctx, start_pos, pos);
if (!result.success()) {
return result;
}
} else {
auto utf8_result = common_parse_utf8_codepoint(ctx.input, pos);
if (utf8_result.status == utf8_parse_result::INCOMPLETE) {
if (!ctx.is_lenient()) {
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
}
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos);
}
if (utf8_result.status == utf8_parse_result::INVALID) {
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
}
pos += utf8_result.bytes_consumed;
}
}
// Reached end without finding closing quote
if (!ctx.is_lenient()) {
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, pos);
}
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos);
}
common_peg_parse_result operator()(const common_peg_python_dict_string_parser & /* p */) {
auto pos = start_pos;
// Parse string content (without quotes)
while (pos < ctx.input.size()) {
char c = ctx.input[pos];
if (c == '\'') {
// Found closing quote - success (don't consume it)
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos);
}
if (c == '\\') {
auto result = handle_escape_sequence(ctx, start_pos, pos);
auto result = handle_escape_sequence(ctx, start_pos, pos, p.delimiter);
if (!result.success()) {
return result;
}
@ -988,8 +937,7 @@ void common_peg_arena::resolve_refs() {
std::is_same_v<T, common_peg_ref_parser> ||
std::is_same_v<T, common_peg_until_parser> ||
std::is_same_v<T, common_peg_literal_parser> ||
std::is_same_v<T, common_peg_json_string_parser> ||
std::is_same_v<T, common_peg_python_dict_string_parser> ||
std::is_same_v<T, common_peg_string_parser> ||
std::is_same_v<T, common_peg_chars_parser> ||
std::is_same_v<T, common_peg_any_parser> ||
std::is_same_v<T, common_peg_space_parser>) {
@ -1065,10 +1013,8 @@ std::string common_peg_arena::dump_impl(common_peg_parser_id
return "CharRepeat(" + p.pattern + ", " + std::to_string(p.min_count) + ", unbounded)";
}
return "CharRepeat(" + p.pattern + ", " + std::to_string(p.min_count) + ", " + std::to_string(p.max_count) + ")";
} else if constexpr (std::is_same_v<T, common_peg_json_string_parser>) {
return "JsonString()";
} else if constexpr (std::is_same_v<T, common_peg_python_dict_string_parser>) {
return "PythonDictString()";
} else if constexpr (std::is_same_v<T, common_peg_string_parser>) {
return "String(" + std::string(1, p.delimiter) + ")";
} else if constexpr (std::is_same_v<T, common_peg_until_parser>) {
return "Until(" + string_join(p.delimiters, " | ") + ")";
} else if constexpr (std::is_same_v<T, common_peg_schema_parser>) {
@ -1281,47 +1227,25 @@ common_peg_arena common_peg_parser_builder::build() {
// String primitives
common_peg_parser common_peg_parser_builder::json_string_content() {
return wrap(arena_.add_parser(common_peg_json_string_parser{}));
}
common_peg_parser common_peg_parser_builder::single_quoted_string_content() {
return wrap(arena_.add_parser(common_peg_python_dict_string_parser{}));
common_peg_parser common_peg_parser_builder::string_content(char delimiter) {
return wrap(arena_.add_parser(common_peg_string_parser{delimiter}));
}
common_peg_parser common_peg_parser_builder::double_quoted_string() {
return rule("dq-string",
[this]() { return sequence({ literal("\""), json_string_content(), literal("\""), space() }); });
}
common_peg_parser common_peg_parser_builder::single_quoted_string() {
return rule("sq-string",
[this]() { return sequence({ literal("'"), single_quoted_string_content(), literal("'"), space() }); });
}
common_peg_parser common_peg_parser_builder::flexible_string() {
return rule("flexible-string", [this]() { return choice({ double_quoted_string(), single_quoted_string() }); });
}
// Generic helpers for object/array structure
common_peg_parser common_peg_parser_builder::generic_object(const std::string & name,
const common_peg_parser & string_parser,
const common_peg_parser & value_parser) {
return rule(name, [this, string_parser, value_parser]() {
auto ws = space();
auto member = sequence({ string_parser, ws, literal(":"), ws, value_parser });
auto members = sequence({ member, zero_or_more(sequence({ ws, literal(","), ws, member })) });
return sequence({ literal("{"), ws, choice({ literal("}"), sequence({ members, ws, literal("}") }) }) });
return rule("double-quoted-string", [this]() {
return sequence({literal("\""), string_content('"'), literal("\""), space()});
});
}
common_peg_parser common_peg_parser_builder::generic_array(const std::string & name,
const common_peg_parser & value_parser) {
return rule(name, [this, value_parser]() {
auto ws = space();
auto elements = sequence({ value_parser, zero_or_more(sequence({ literal(","), ws, value_parser })) });
return sequence({ literal("["), ws, choice({ literal("]"), sequence({ elements, ws, literal("]") }) }) });
common_peg_parser common_peg_parser_builder::single_quoted_string() {
return rule("single-quoted-string", [this]() {
return sequence({literal("'"), string_content('\''), literal("'"), space()});
});
}
common_peg_parser common_peg_parser_builder::quoted_string() {
return rule("quoted-string", [this]() {
return choice({double_quoted_string(), single_quoted_string()});
});
}
@ -1344,7 +1268,7 @@ common_peg_parser common_peg_parser_builder::json_number() {
common_peg_parser common_peg_parser_builder::json_string() {
return rule("json-string", [this]() {
return sequence({literal("\""), json_string_content(), literal("\""), space()});
return sequence({literal("\""), string_content('"'), literal("\""), space()});
});
}
@ -1361,11 +1285,36 @@ common_peg_parser common_peg_parser_builder::json_null() {
}
common_peg_parser common_peg_parser_builder::json_object() {
return generic_object("json-object", json_string(), json());
return rule("json-object", [this]() {
auto ws = space();
auto member = sequence({json_string(), ws, literal(":"), ws, json()});
auto members = sequence({member, zero_or_more(sequence({ws, literal(","), ws, member}))});
return sequence({
literal("{"),
ws,
choice({
literal("}"),
sequence({members, ws, literal("}")})
}),
ws
});
});
}
common_peg_parser common_peg_parser_builder::json_array() {
return generic_array("json-array", json());
return rule("json-array", [this]() {
auto ws = space();
auto elements = sequence({json(), zero_or_more(sequence({literal(","), ws, json()}))});
return sequence({
literal("["),
ws,
choice({
literal("]"),
sequence({elements, ws, literal("]")})
}),
ws
});
});
}
common_peg_parser common_peg_parser_builder::json() {
@ -1382,7 +1331,9 @@ common_peg_parser common_peg_parser_builder::json() {
}
common_peg_parser common_peg_parser_builder::python_string() {
return rule("python-string", [this]() { return choice({ double_quoted_string(), single_quoted_string() }); });
return rule("python-string", [this]() {
return choice({double_quoted_string(), single_quoted_string()});
});
}
common_peg_parser common_peg_parser_builder::python_number() {
@ -1390,24 +1341,63 @@ common_peg_parser common_peg_parser_builder::python_number() {
}
common_peg_parser common_peg_parser_builder::python_bool() {
return rule("python-bool", [this]() { return sequence({ choice({ literal("True"), literal("False") }), space() }); });
return rule("python-bool", [this]() {
return sequence({
choice({literal("True"), literal("False")}),
space()
});
});
}
common_peg_parser common_peg_parser_builder::python_null() {
return rule("python-none", [this]() { return sequence({ literal("None"), space() }); });
return rule("python-none", [this]() {
return sequence({literal("None"), space()});
});
}
common_peg_parser common_peg_parser_builder::python_dict() {
return generic_object("python-dict", python_string(), python_value());
return rule("python-dict", [this]() {
auto ws = space();
auto member = sequence({python_string(), ws, literal(":"), ws, python_value()});
auto members = sequence({member, zero_or_more(sequence({ws, literal(","), ws, member}))});
return sequence({
literal("{"),
ws,
choice({
literal("}"),
sequence({members, ws, literal("}")})
}),
ws
});
});
}
common_peg_parser common_peg_parser_builder::python_array() {
return generic_array("python-array", python_value());
return rule("python-array", [this]() {
auto ws = space();
auto elements = sequence({python_value(), zero_or_more(sequence({literal(","), ws, python_value()}))});
return sequence({
literal("["),
ws,
choice({
literal("]"),
sequence({elements, ws, literal("]")})
}),
ws
});
});
}
common_peg_parser common_peg_parser_builder::python_value() {
return rule("python-value", [this]() {
return choice({ python_dict(), python_array(), python_string(), python_number(), python_bool(), python_null() });
return choice({
python_dict(),
python_array(),
python_string(),
python_number(),
python_bool(),
python_null()
});
});
}
@ -1528,8 +1518,7 @@ static std::unordered_set<std::string> collect_reachable_rules(
std::is_same_v<T, common_peg_chars_parser> ||
std::is_same_v<T, common_peg_space_parser> ||
std::is_same_v<T, common_peg_any_parser> ||
std::is_same_v<T, common_peg_json_string_parser> ||
std::is_same_v<T, common_peg_python_dict_string_parser>) {
std::is_same_v<T, common_peg_string_parser>) {
// These parsers do not have any children
} else if constexpr (std::is_same_v<T, common_peg_sequence_parser>) {
for (auto child : p.children) {
@ -1665,10 +1654,9 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
return result + "{" + std::to_string(p.min_count) + "}";
}
return result + "{" + std::to_string(p.min_count) + "," + std::to_string(p.max_count) + "}";
} else if constexpr (std::is_same_v<T, common_peg_json_string_parser>) {
return R"(( [^"\\] | "\\" ( ["\\/ bfnrt] | "u" [0-9a-fA-F]{4} ) )*)";
} else if constexpr (std::is_same_v<T, common_peg_python_dict_string_parser>) {
return R"(( [^"\\] | "\\" ( ["\\/ bfnrt] | "u" [0-9a-fA-F]{4} ) )*)";
} else if constexpr (std::is_same_v<T, common_peg_string_parser>) {
const std::string delim(1, p.delimiter);
return R"(( [^)" + delim + R"(\\] | "\\" ( [)" + delim + R"(\\/ bfnrt] | "u" [0-9a-fA-F]{4} ) )*)";
} else if constexpr (std::is_same_v<T, common_peg_until_parser>) {
if (p.delimiters.empty()) {
return ".*";
@ -1798,10 +1786,8 @@ static nlohmann::json serialize_parser_variant(const common_peg_parser_variant &
{"min_count", p.min_count},
{"max_count", p.max_count}
};
} else if constexpr (std::is_same_v<T, common_peg_json_string_parser>) {
return json{{"type", "json_string"}};
} else if constexpr (std::is_same_v<T, common_peg_python_dict_string_parser>) {
return json{{ "type", "python_dict_string" }};
} else if constexpr (std::is_same_v<T, common_peg_string_parser>) {
return json{{"type", "string"}, {"delimiter", std::string(1, p.delimiter)}};
} else if constexpr (std::is_same_v<T, common_peg_until_parser>) {
return json{{"type", "until"}, {"delimiters", p.delimiters}};
} else if constexpr (std::is_same_v<T, common_peg_schema_parser>) {
@ -1928,11 +1914,15 @@ static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json
}
return parser;
}
if (type == "json_string") {
return common_peg_json_string_parser{};
}
if (type == "python_dict_string") {
return common_peg_python_dict_string_parser{};
if (type == "string") {
if (!j.contains("delimiter")) {
throw std::runtime_error("string parser missing delimiter field.");
}
std::string delimiter = j["delimiter"];
if (delimiter.empty()) {
throw std::runtime_error("string parser delimiter is empty.");
}
return common_peg_string_parser{delimiter[0]};
}
if (type == "until") {
if (!j.contains("delimiters") || !j["delimiters"].is_array()) {

View File

@ -231,8 +231,9 @@ struct common_peg_chars_parser {
int max_count; // -1 for unbounded
};
struct common_peg_json_string_parser {};
struct common_peg_python_dict_string_parser {};
struct common_peg_string_parser {
char delimiter;
};
struct common_peg_until_parser {
std::vector<std::string> delimiters;
@ -280,8 +281,7 @@ using common_peg_parser_variant = std::variant<
common_peg_any_parser,
common_peg_space_parser,
common_peg_chars_parser,
common_peg_json_string_parser,
common_peg_python_dict_string_parser,
common_peg_string_parser,
common_peg_until_parser,
common_peg_schema_parser,
common_peg_rule_parser,
@ -340,10 +340,6 @@ class common_peg_parser_builder {
common_peg_parser wrap(common_peg_parser_id id) { return common_peg_parser(id, *this); }
common_peg_parser add(const common_peg_parser_variant & p) { return wrap(arena_.add_parser(p)); }
// Generic helpers for building object/array structures with configurable string/value parsers.
common_peg_parser generic_object(const std::string & name, const common_peg_parser & string_parser, const common_peg_parser & value_parser);
common_peg_parser generic_array(const std::string & name, const common_peg_parser & value_parser);
public:
common_peg_parser_builder();
@ -444,13 +440,10 @@ class common_peg_parser_builder {
common_peg_parser single_quoted_string();
// Matches a string that accepts both double-quoted and single-quoted styles.
common_peg_parser flexible_string();
common_peg_parser quoted_string();
// Matches double-quoted string content without the surrounding quotes.
common_peg_parser json_string_content();
// Matches single-quoted string content without the surrounding quotes.
common_peg_parser single_quoted_string_content();
// Matches string content without the surrounding delimiter.
common_peg_parser string_content(char delimiter);
// Creates a complete JSON parser supporting objects, arrays, strings, numbers, booleans, and null.
// value -> object | array | string | number | true | false | null

219
common/reasoning-budget.cpp Normal file
View File

@ -0,0 +1,219 @@
#include "reasoning-budget.h"
#include "common.h"
#include "unicode.h"
#include "log.h"
#include <cmath>
#include <cstdint>
#include <string>
#include <vector>
// Incremental matcher for a fixed token sequence.
// Feed tokens one at a time via advance(); it returns true exactly when the
// final token of the sequence has just been consumed.
struct token_matcher {
    std::vector<llama_token> tokens;
    size_t pos = 0;

    // Consume one token and report whether the full sequence just completed.
    // On a mismatch this falls back to the longest prefix of `tokens` that is
    // still a suffix of the consumed stream (KMP-style failure function), so
    // sequences with self-overlapping prefixes (e.g. [A, A, B] against the
    // stream A A A B) are not missed. The previous behavior of resetting to
    // position 0 (or 1 if the token equals tokens[0]) is a strict subset of
    // this, so non-overlapping sequences behave identically.
    bool advance(llama_token token) {
        if (tokens.empty()) {
            return false;
        }
        while (true) {
            if (token == tokens[pos]) {
                pos++;
                if (pos >= tokens.size()) {
                    pos = 0;
                    return true;
                }
                return false;
            }
            if (pos == 0) {
                return false;
            }
            // find the longest proper prefix of tokens[0..pos) that is also a
            // suffix of it, then retry the current token from that position
            size_t next = 0;
            for (size_t len = pos - 1; len > 0; len--) {
                bool match = true;
                for (size_t i = 0; i < len; i++) {
                    if (tokens[i] != tokens[pos - len + i]) {
                        match = false;
                        break;
                    }
                }
                if (match) {
                    next = len;
                    break;
                }
            }
            pos = next; // strictly decreases, so the loop terminates
        }
    }

    void reset() { pos = 0; }
};
// Mutable per-sampler state for the reasoning-budget sampler.
// NOTE: field order must match the positional initialization in
// common_reasoning_budget_init().
struct common_reasoning_budget_ctx {
    const llama_vocab * vocab;              // used for UTF-8 boundary detection; may be nullptr
    token_matcher start_matcher;            // watches for the sequence that opens the reasoning block
    token_matcher end_matcher;              // watches for the sequence that naturally closes it
    std::vector<llama_token> forced_tokens; // sequence forced token-by-token once the budget runs out
    int32_t budget;    // maximum tokens in reasoning block
    int32_t remaining; // tokens remaining in budget
    common_reasoning_budget_state state;    // current position in the IDLE/COUNTING/.../DONE state machine
    // for forcing
    size_t force_pos; // next position in forced_tokens to force
};
// Sampler interface hook: human-readable sampler name (sampler arg unused).
static const char * common_reasoning_budget_name(const struct llama_sampler * /*smpl*/) {
    static const char * const k_name = "reasoning-budget";
    return k_name;
}
// Sampler interface hook: observe each accepted token and drive the
// IDLE -> COUNTING -> (WAITING_UTF8) -> FORCING -> DONE state machine.
// Budget accounting happens here; logit masking happens in apply().
static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_token token) {
    auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
    switch (ctx->state) {
        case REASONING_BUDGET_IDLE:
        {
            // watch for the start sequence (e.g. <think>); once it completes,
            // begin counting against the budget
            if (ctx->start_matcher.advance(token)) {
                ctx->state = REASONING_BUDGET_COUNTING;
                ctx->remaining = ctx->budget;
                LOG_INF("reasoning-budget: activated, budget=%d tokens\n", ctx->budget);
                if (ctx->remaining <= 0) {
                    // zero budget: skip straight to forcing the end sequence
                    ctx->state = REASONING_BUDGET_FORCING;
                    ctx->force_pos = 0;
                    LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
                }
            }
            break;
        }
        case REASONING_BUDGET_COUNTING:
        case REASONING_BUDGET_WAITING_UTF8:
        {
            // a naturally produced end sequence deactivates the sampler for good
            if (ctx->end_matcher.advance(token)) {
                ctx->state = REASONING_BUDGET_DONE;
                LOG_INF("reasoning-budget: deactivated (natural end)\n");
                break;
            }
            // determine whether the text produced so far ends on a UTF-8
            // boundary; without a vocab we assume it does
            bool utf8_complete = true;
            if (ctx->vocab != nullptr) {
                const std::string piece = common_token_to_piece(ctx->vocab, token, false);
                utf8_complete = common_utf8_is_complete(piece);
            }
            if (ctx->state == REASONING_BUDGET_WAITING_UTF8) {
                // budget already exhausted; only waiting for the multi-byte
                // character to finish before forcing the end sequence
                if (utf8_complete) {
                    ctx->state = REASONING_BUDGET_FORCING;
                    ctx->force_pos = 0;
                    ctx->end_matcher.reset();
                    LOG_INF("reasoning-budget: UTF-8 complete, now forcing end sequence\n");
                }
            } else if (ctx->state == REASONING_BUDGET_COUNTING) {
                ctx->remaining--;
                if (ctx->remaining <= 0) {
                    if (utf8_complete) {
                        ctx->state = REASONING_BUDGET_FORCING;
                        ctx->force_pos = 0;
                        ctx->end_matcher.reset();
                        LOG_INF("reasoning-budget: budget exhausted, forcing end sequence\n");
                    } else {
                        // never cut a multi-byte character in half: allow extra
                        // tokens until the UTF-8 sequence completes
                        ctx->state = REASONING_BUDGET_WAITING_UTF8;
                        ctx->end_matcher.reset();
                        LOG_INF("reasoning-budget: budget exhausted, waiting for UTF-8 completion\n");
                    }
                }
            }
            break;
        }
        case REASONING_BUDGET_FORCING:
            // force_pos is advanced in apply(), not here.
            // This ensures the first forced token isn't skipped when the sampler
            // is initialized directly in FORCING state (e.g. COUNTING + budget=0)
            break;
        case REASONING_BUDGET_DONE:
            break;
    }
}
static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
if (ctx->state != REASONING_BUDGET_FORCING) {
// passthrough — don't modify logits
return;
}
if (ctx->force_pos >= ctx->forced_tokens.size()) {
return;
}
const llama_token forced = ctx->forced_tokens[ctx->force_pos];
// set all logits to -inf except the forced token
for (size_t i = 0; i < cur_p->size; i++) {
if (cur_p->data[i].id != forced) {
cur_p->data[i].logit = -INFINITY;
}
}
// advance to next forced token (done here rather than in accept so that
// the first forced token isn't skipped when starting in FORCING state)
ctx->force_pos++;
if (ctx->force_pos >= ctx->forced_tokens.size()) {
ctx->state = REASONING_BUDGET_DONE;
LOG_INF("reasoning-budget: forced sequence complete, done\n");
}
}
static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
ctx->state = REASONING_BUDGET_IDLE;
ctx->remaining = ctx->budget;
ctx->start_matcher.reset();
ctx->end_matcher.reset();
ctx->force_pos = 0;
}
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx;
return common_reasoning_budget_init(
ctx->vocab,
ctx->start_matcher.tokens,
ctx->end_matcher.tokens,
ctx->forced_tokens,
ctx->budget,
ctx->state);
}
static void common_reasoning_budget_free(struct llama_sampler * smpl) {
delete (common_reasoning_budget_ctx *) smpl->ctx;
}
// vtable for the reasoning-budget sampler; the backend hooks are not needed
// (this sampler runs on the CPU token stream only) and stay null
static struct llama_sampler_i common_reasoning_budget_i = {
    /* .name = */ common_reasoning_budget_name,
    /* .accept = */ common_reasoning_budget_accept,
    /* .apply = */ common_reasoning_budget_apply,
    /* .reset = */ common_reasoning_budget_reset,
    /* .clone = */ common_reasoning_budget_clone,
    /* .free = */ common_reasoning_budget_free,
    /* .backend_init = */ nullptr,
    /* .backend_accept = */ nullptr,
    /* .backend_apply = */ nullptr,
    /* .backend_set_input = */ nullptr,
};
// Construct a reasoning-budget sampler (see reasoning-budget.h for the full
// contract). The context starts with a full budget, matchers at position 0
// and nothing forced yet.
struct llama_sampler * common_reasoning_budget_init(
        const struct llama_vocab * vocab,
        const std::vector<llama_token> & start_tokens,
        const std::vector<llama_token> & end_tokens,
        const std::vector<llama_token> & forced_tokens,
        int32_t budget,
        common_reasoning_budget_state initial_state) {
    // a COUNTING start with no budget left degenerates straight to FORCING
    if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) {
        initial_state = REASONING_BUDGET_FORCING;
    }

    auto * ctx = new common_reasoning_budget_ctx {
        /* .vocab         = */ vocab,
        /* .start_matcher = */ { start_tokens, 0 },
        /* .end_matcher   = */ { end_tokens, 0 },
        /* .forced_tokens = */ forced_tokens,
        /* .budget        = */ budget,
        /* .remaining     = */ budget,
        /* .state         = */ initial_state,
        /* .force_pos     = */ 0,
    };

    return llama_sampler_init(&common_reasoning_budget_i, ctx);
}

41
common/reasoning-budget.h Normal file
View File

@ -0,0 +1,41 @@
#pragma once
#include "llama.h"
#include <cstdint>
#include <vector>
// Lifecycle states of the reasoning-budget sampler
// (see common_reasoning_budget_init below for the transition diagram).
enum common_reasoning_budget_state {
    REASONING_BUDGET_IDLE,         // waiting for start sequence
    REASONING_BUDGET_COUNTING,     // counting down tokens
    REASONING_BUDGET_FORCING,      // forcing budget message + end sequence
    REASONING_BUDGET_WAITING_UTF8, // budget exhausted, waiting for UTF-8 completion
    REASONING_BUDGET_DONE,         // passthrough forever
};
// Creates a reasoning budget sampler that limits token generation inside a
// reasoning block (e.g. between <think> and </think>).
//
// State machine: IDLE -> COUNTING -> [WAITING_UTF8 ->] FORCING -> DONE
// (WAITING_UTF8 is entered only when the budget runs out mid UTF-8 sequence)
// IDLE: passthrough, watching for start_tokens sequence
// COUNTING: counting down remaining tokens, watching for natural end_tokens
// WAITING_UTF8: budget exhausted, allowing tokens to complete a UTF-8 sequence
// FORCING: forces forced_tokens token-by-token (all other logits -> -inf)
// DONE: passthrough forever
//
// Parameters:
// vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr)
// start_tokens - token sequence that activates counting
// end_tokens - token sequence for natural deactivation
// forced_tokens - token sequence forced when budget expires
// budget - max tokens allowed in the reasoning block
// initial_state - initial state of the sampler (e.g. IDLE or COUNTING)
// note: COUNTING with budget <= 0 is promoted to FORCING
//
struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
int32_t budget,
common_reasoning_budget_state initial_state);

View File

@ -2,6 +2,7 @@
#include "common.h"
#include "log.h"
#include "reasoning-budget.h"
#include <algorithm>
#include <cmath>
@ -250,6 +251,17 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
}
}
// reasoning budget sampler — added first so it can force tokens before other samplers
if (params.reasoning_budget_tokens >= 0 && !params.reasoning_budget_forced.empty()) {
samplers.push_back(common_reasoning_budget_init(
vocab,
params.reasoning_budget_start,
params.reasoning_budget_end,
params.reasoning_budget_forced,
params.reasoning_budget_tokens,
params.reasoning_budget_activate_immediately ? REASONING_BUDGET_COUNTING : REASONING_BUDGET_IDLE));
}
if (params.has_logit_bias()) {
samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data()));
}

View File

@ -1,8 +1,10 @@
#include "unicode.h"
#include <algorithm>
#include <cassert>
#include <stdexcept>
#include <vector>
#include <string>
#include <vector>
// implementation adopted from src/unicode.cpp
@ -67,6 +69,20 @@ utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t off
return utf8_parse_result(utf8_parse_result::INVALID);
}
// Check whether the string ends on a complete UTF-8 sequence.
// Scans backwards (at most 4 bytes) for the first non-continuation byte and
// verifies that enough bytes follow it to cover the sequence length its lead
// byte announces. An empty string counts as complete; a tail consisting only
// of continuation bytes counts as incomplete.
bool common_utf8_is_complete(const std::string & s) {
    if (s.empty()) {
        return true;
    }
    const size_t n        = s.size();
    const size_t max_back = n < 4 ? n : (size_t) 4;
    for (size_t back = 1; back <= max_back; back++) {
        const unsigned char byte = (unsigned char) s[n - back];
        if ((byte & 0xC0) == 0x80) {
            continue; // continuation byte — keep looking for the lead byte
        }
        // lead (or ASCII) byte found: how many bytes should its sequence span?
        size_t need;
        if (byte >= 0xF0) {
            need = 4;
        } else if (byte >= 0xE0) {
            need = 3;
        } else if (byte >= 0xC0) {
            need = 2;
        } else {
            need = 1; // ASCII (or an invalid byte treated as a 1-byte unit)
        }
        return back >= need;
    }
    return false;
}
std::string common_unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) {
std::string result;
for (size_t i = 0; i < cps.size(); ++i) {

View File

@ -20,6 +20,9 @@ struct utf8_parse_result {
// Returns 0 for invalid first bytes
size_t common_utf8_sequence_length(unsigned char first_byte);
// Check if a string ends with a complete UTF-8 sequence.
bool common_utf8_is_complete(const std::string & s);
// Parse a single UTF-8 codepoint from input
utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t offset);

View File

@ -4390,15 +4390,31 @@ class Qwen3Model(Qwen2Model):
hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
self.origin_hf_arch = hparams.get('architectures', [None])[0]
# a bit hacky, but currently the only way to detect if this is a rerank model
# ref: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B
if self._is_qwen3_reranker():
self._find_rerank_config()
def _is_qwen3_reranker(self) -> bool:
readme_path = self.dir_model / "README.md"
readme_text = ""
if readme_path.exists():
with readme_path.open("r", encoding="utf-8") as f:
readme_text = f.read()
if "# Qwen3-Reranker" in readme_text:
self._find_rerank_config()
name_hints = [
str(self.dir_model.name),
str(self.hparams.get("_name_or_path", "")),
str(self.hparams.get("model_type", "")),
str(self.origin_hf_arch or ""),
]
name_hints = [hint.lower() for hint in name_hints if hint]
if "# qwen3-reranker" in readme_text.lower() or "# qwen3-vl-reranker" in readme_text.lower():
return True
if any("qwen3-reranker" in hint or "qwen3-vl-reranker" in hint for hint in name_hints):
return True
return "sequenceclassification" in (self.origin_hf_arch or "").lower()
def set_vocab(self):
# deal with intern-s1-mini

View File

@ -599,7 +599,13 @@ If KleidiAI is enabled, the output will contain a line similar to:
```
load_tensors: CPU_KLEIDIAI model buffer size = 3474.00 MiB
```
KleidiAI's microkernels implement optimized tensor operations using Arm CPU features such as dotprod, int8mm and SME. llama.cpp selects the most efficient kernel based on runtime CPU feature detection. However, on platforms that support SME, you must manually enable SME microkernels by setting the environment variable `GGML_KLEIDIAI_SME=1`.
KleidiAI's microkernels implement optimized tensor operations using Arm CPU features such as dotprod, int8mm, SVE, and SME. llama.cpp selects the most efficient kernels at runtime based on detected CPU capabilities.
On CPUs that support SME, SME microkernels are enabled automatically using runtime detection.
The environment variable GGML_KLEIDIAI_SME can be used to control SME behavior:
- Not set: enable SME automatically if supported and detected.
- 0: disable SME.
- <n> > 0: enable SME and assume <n> available SME units (override auto detection).
If SME is not supported by the CPU, SME microkernels are always disabled.
Depending on your build target, other higher priority backends may be enabled by default. To ensure the CPU backend is used, you must disable the higher priority backends either at compile time, e.g. -DGGML_METAL=OFF, or during run-time using the command line option `--device none`.

View File

@ -23,7 +23,7 @@ Legend:
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | | 🟡 | ✅ | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
@ -31,7 +31,7 @@ Legend:
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | | 🟡 | ✅ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
@ -47,7 +47,7 @@ Legend:
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| GATED_DELTA_NET | ❌ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| GATED_DELTA_NET | ❌ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
@ -64,7 +64,7 @@ Legend:
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | | ✅ | ✅ | ❌ | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
@ -76,7 +76,7 @@ Legend:
| OUT_PROD | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | 🟡 |
| PAD | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| POOL_1D | ❌ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| POOL_1D | ❌ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
@ -86,7 +86,7 @@ Legend:
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ROPE | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | | ✅ | ❌ | ❌ | ❌ |
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | | ✅ | ❌ | ❌ | ❌ |
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
@ -97,13 +97,13 @@ Legend:
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | | 🟡 | ✅ | ❌ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | | 🟡 | ✅ | ❌ | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | | 🟡 | ✅ | ❌ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -633,7 +633,7 @@ class SchemaConverter:
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
items = schema.get('items') or schema['prefixItems']
items = schema.get('items', schema.get('prefixItems'))
if isinstance(items, list):
return self._add_rule(
rule_name,

View File

@ -8,7 +8,12 @@ extern "C" {
#define RPC_PROTO_MAJOR_VERSION 3
#define RPC_PROTO_MINOR_VERSION 6
#define RPC_PROTO_PATCH_VERSION 0
#define RPC_PROTO_PATCH_VERSION 1
#ifdef __cplusplus
static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION");
#endif
#define GGML_RPC_MAX_SERVERS 16
// backend API

View File

@ -202,8 +202,9 @@
#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
#define ggml_quantize_mat_q8_K_4x1_generic ggml_quantize_mat_q8_K_4x1
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0

File diff suppressed because it is too large Load Diff

View File

@ -520,7 +520,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
/* .required_cpu = */ CPU_FEATURE_I8MM,
/* .lhs_type = */ GGML_TYPE_F32,
/* .rhs_type = */ GGML_TYPE_Q4_0,
/* .op_type = */ GGML_TYPE_F32,
@ -631,7 +631,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
/* .required_cpu = */ CPU_FEATURE_I8MM,
/* .lhs_type = */ GGML_TYPE_F32,
/* .rhs_type = */ GGML_TYPE_Q4_0,
/* .op_type = */ GGML_TYPE_F32,
@ -801,7 +801,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
/* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
/* .required_cpu = */ CPU_FEATURE_I8MM,
/* .lhs_type = */ GGML_TYPE_F32,
/* .rhs_type = */ GGML_TYPE_Q8_0,
/* .op_type = */ GGML_TYPE_F32,

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -28,13 +28,17 @@ template <int K, int N> struct block {
// control size
static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding");
static_assert(sizeof(block<4, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding");
static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<4,16> size/padding");
static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding");
static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding");
static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 16, "wrong block<8,16> size/padding");
using block_q4_0x4 = block<4, 4>;
using block_q4_0x8 = block<4, 8>;
using block_q4_0x16 = block<4, 16>;
using block_q8_0x4 = block<8, 4>;
using block_q8_0x8 = block<8, 8>;
using block_q8_0x16 = block<8, 16>;
struct block_q4_Kx8 {
ggml_half d[8]; // super-block scale for quantized scales
@ -44,7 +48,14 @@ struct block_q4_Kx8 {
};
static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
struct block_q4_Kx16 {
ggml_half d[16]; // super-block scale for quantized scales
ggml_half dmin[16]; // super-block scale for quantized mins
uint8_t scales[192]; // scales and mins, quantized with 6 bits
uint8_t qs[2048]; // 4--bit quants
};
static_assert(sizeof(block_q4_Kx16) == sizeof(ggml_half) * 32 + K_SCALE_SIZE * 16 + QK_K * 8, "wrong q4_K block size/padding");
struct block_q2_Kx8 {
ggml_half d[8]; // super-block scale for quantized scales
ggml_half dmin[8]; // super-block scale for quantized mins
@ -53,6 +64,13 @@ struct block_q2_Kx8 {
};
static_assert(sizeof(block_q2_Kx8) == sizeof(ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding");
struct block_q2_Kx16 {
ggml_half d[16]; // Super-block scale for quantized scales
ggml_half dmin[16]; // Super-block scale for quantized mins
uint8_t scales[256]; // Sub-block scales (16 cols * 16 sub-blocks)
uint8_t qs[1024]; // Data (16 cols * 64 bytes per block)
};
static_assert(sizeof(block_q2_Kx16) == sizeof(ggml_half) * 32 + QK_K + QK_K * 4, "wrong q2_K block size/padding");
struct block_q5_Kx8 {
ggml_half d[8]; // super-block scale for quantized scales
@ -97,6 +115,12 @@ struct block_iq4_nlx8 {
static_assert(sizeof(block_iq4_nlx8) == 8 * sizeof(ggml_half) + QK4_NL * 4, "wrong iq4_nlx8 block size/padding");
struct block_iq4_nlx16 {
ggml_half d[16]; // deltas for 16 iq4_nl blocks
uint8_t qs[QK4_NL * 8]; // nibbles / quants for 16 iq4_nl blocks
};
static_assert(sizeof(block_iq4_nlx16) == 16 * sizeof(ggml_half) + QK4_NL * 8, "wrong iq4_nlx16 block size/padding");
struct block_mxfp4x4 {
uint8_t e[4];
uint8_t qs[QK_MXFP4 * 2];
@ -109,7 +133,6 @@ struct block_mxfp4x8 {
};
static_assert(sizeof(block_mxfp4x8) == 8 + QK_MXFP4 * 4, "wrong mxfp4x8 block size/padding");
#if defined(__cplusplus)
extern "C" {
#endif
@ -132,6 +155,8 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@ -146,10 +171,22 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
#if defined __riscv_zvfh
void ggml_quantize_mat_q8_0_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
#endif
// Native implementations
void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
@ -170,6 +207,8 @@ void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@ -184,10 +223,22 @@ void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
#if defined __riscv_zvfh
void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
#endif
#if defined(__cplusplus)
} // extern "C"

View File

@ -2,28 +2,29 @@
#include "ggml-cuda/common.cuh"
template <int S_v, bool KDA>
__global__ void gated_delta_net_cuda(const float * q,
const float * k,
const float * v,
const float * g,
const float * beta,
const float * curr_state,
float * dst,
int64_t H,
int64_t n_tokens,
int64_t n_seqs,
int64_t sq1,
int64_t sq2,
int64_t sq3,
int64_t sv1,
int64_t sv2,
int64_t sv3,
int64_t sb1,
int64_t sb2,
int64_t sb3,
int64_t rq1,
int64_t rq3,
float scale) {
__global__ void __launch_bounds__(S_v, 1)
gated_delta_net_cuda(const float * q,
const float * k,
const float * v,
const float * g,
const float * beta,
const float * curr_state,
float * dst,
const int64_t H,
const int64_t n_tokens,
const int64_t n_seqs,
const int64_t sq1,
const int64_t sq2,
const int64_t sq3,
const int64_t sv1,
const int64_t sv2,
const int64_t sv3,
const int64_t sb1,
const int64_t sb2,
const int64_t sb3,
const int64_t rq1,
const int64_t rq3,
const float scale) {
const int64_t h_idx = blockIdx.x;
const int64_t sequence = blockIdx.y;
const int col = threadIdx.x; // each thread owns one column
@ -40,8 +41,14 @@ __global__ void gated_delta_net_cuda(const float * q,
curr_state += state_offset;
attn_data += (sequence * n_tokens * H + h_idx) * S_v;
// Load state column into registers
// GCN and CDNA devices spill registers, we use shared mem for them. See https://github.com/ggml-org/llama.cpp/pull/20282#issuecomment-4025770229
// TODO: check optimal path for RDNA1 and RDNA2 devices.
#if (defined(GGML_USE_HIP) && !defined(RDNA3) && !defined(RDNA4)) || defined(GGML_USE_MUSA)
extern __shared__ float s_shared[];
float * s = s_shared + col * S_v;
#else
float s[S_v];
#endif
#pragma unroll
for (int i = 0; i < S_v; i++) {
s[i] = curr_state[i * S_v + col];
@ -114,6 +121,15 @@ __global__ void gated_delta_net_cuda(const float * q,
}
}
// Shared-memory size (bytes) required by the gated-delta-net kernel for a
// state width of sv. GCN/CDNA AMD devices and MooreThreads devices keep the
// sv x sv state tile in shared memory to avoid register spills (see the
// matching #if in the kernel); all other devices use registers and need none.
static size_t calculate_smem(const int sv, int cc)
{
    const bool is_amd_non_rdna34 = GGML_CUDA_CC_IS_AMD(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_RDNA4(cc);
    const bool use_shared_state  = is_amd_non_rdna34 || GGML_CUDA_CC_IS_MTHREADS(cc);
    return use_shared_state ? sv * sv * sizeof(float) : 0;
}
template <bool KDA>
static void launch_gated_delta_net(
const float * q_d, const float * k_d, const float * v_d,
@ -129,25 +145,36 @@ static void launch_gated_delta_net(
dim3 grid_dims(H, n_seqs, 1);
dim3 block_dims(S_v, 1, 1);
int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
switch (S_v) {
case 32:
gated_delta_net_cuda<32, KDA><<<grid_dims, block_dims, 0, stream>>>(
case 32: {
constexpr int sv = 32;
size_t smem = calculate_smem(sv, cc);
gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale);
break;
case 64:
gated_delta_net_cuda<64, KDA><<<grid_dims, block_dims, 0, stream>>>(
}
case 64: {
constexpr int sv = 64;
size_t smem = calculate_smem(sv, cc);
gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale);
break;
case 128:
gated_delta_net_cuda<128, KDA><<<grid_dims, block_dims, 0, stream>>>(
}
case 128: {
constexpr int sv = 128;
size_t smem = calculate_smem(sv, cc);
gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale);
break;
}
default:
GGML_ABORT("fatal error");
break;

View File

@ -76,7 +76,7 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
int row = tid / load_cols;
int col = tid % load_cols;
#pragma unroll
for (int idx = tid; idx < total_elems; idx += split_d_inner) {
for (int idx = 0; idx < total_elems; idx += split_d_inner) {
if (row < (int)split_d_inner) {
smem[row * n_cols + col] = x_block[row * stride_x + col];
}
@ -84,6 +84,9 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
col += split_d_inner;
row += col / load_cols;
col = col % load_cols;
if (idx >= total_elems - tid - split_d_inner) {
break;
}
}
__syncthreads();

View File

@ -75,6 +75,10 @@ struct ggml_metal {
// abort ggml_metal_graph_compute if callback returns true
ggml_abort_callback abort_callback;
void * abort_callback_data;
// error state - set when a command buffer fails during synchronize
// once set, graph_compute will return GGML_STATUS_FAILED until the backend is recreated
bool has_error;
};
ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
@ -158,6 +162,8 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
res->capture_started = false;
res->capture_scope = nil;
res->has_error = false;
res->gf = nil;
res->encode_async = nil;
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
@ -246,7 +252,8 @@ void ggml_metal_synchronize(ggml_metal_t ctx) {
if (status == MTLCommandBufferStatusError) {
GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
}
GGML_ABORT("fatal error");
ctx->has_error = true;
return;
}
}
}
@ -262,7 +269,15 @@ void ggml_metal_synchronize(ggml_metal_t ctx) {
if (status == MTLCommandBufferStatusError) {
GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
}
GGML_ABORT("fatal error");
// release this and all remaining command buffers before returning
for (size_t j = i; j < ctx->cmd_bufs_ext.count; ++j) {
[ctx->cmd_bufs_ext[j] release];
}
[ctx->cmd_bufs_ext removeAllObjects];
ctx->has_error = true;
return;
}
[cmd_buf release];
@ -414,6 +429,11 @@ bool ggml_metal_cpy_tensor_async(ggml_metal_t ctx_src, ggml_metal_t ctx_dst, con
}
enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * gf) {
if (ctx->has_error) {
GGML_LOG_ERROR("%s: backend is in error state from a previous command buffer failure - recreate the backend to recover\n", __func__);
return GGML_STATUS_FAILED;
}
// number of nodes encoded by the main thread (empirically determined)
const int n_main = MAX(64, 0.1*gf->n_nodes);

View File

@ -1717,12 +1717,29 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_met
char base[256];
char name[256];
snprintf(base, 256, "kernel_upscale_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s", base);
const int32_t mode_flags = ggml_get_op_params_i32(op, 0);
const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS);
if (mode == GGML_SCALE_MODE_BILINEAR) {
snprintf(base, 256, "kernel_upscale_bilinear_%s", ggml_type_name(op->src[0]->type));
} else if (mode == GGML_SCALE_MODE_BICUBIC) {
snprintf(base, 256, "kernel_upscale_bicubic_%s", ggml_type_name(op->src[0]->type));
} else {
snprintf(base, 256, "kernel_upscale_nearest_%s", ggml_type_name(op->src[0]->type));
}
snprintf(name, 256, "%s_aa=%d", base, antialias);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_bool(cv, antialias, FC_UPSCALE + 0);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
}
return res;

View File

@ -1108,7 +1108,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
op->type == GGML_TYPE_F32 &&
(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
case GGML_OP_UPSCALE:
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_POOL_1D:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_POOL_2D:

View File

@ -83,6 +83,7 @@
#define FC_UNARY 1200
#define FC_BIN 1300
#define FC_SUM_ROWS 1400
#define FC_UPSCALE 1500
// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPSG 8
@ -890,6 +891,7 @@ typedef struct {
float sf1;
float sf2;
float sf3;
float poffs;
} ggml_metal_kargs_upscale;
typedef struct {

View File

@ -1963,6 +1963,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
(
op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function
op->src[0]->type == GGML_TYPE_F16 ||
op->src[0]->type == GGML_TYPE_BF16 ||
op->src[0]->type == GGML_TYPE_Q4_0 ||
op->src[0]->type == GGML_TYPE_Q4_1 ||
op->src[0]->type == GGML_TYPE_Q5_0 ||
@ -1977,6 +1978,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
op->src[0]->type == GGML_TYPE_Q4_K ||
op->src[0]->type == GGML_TYPE_Q5_K ||
op->src[0]->type == GGML_TYPE_Q6_K ||
op->src[0]->type == GGML_TYPE_Q2_K ||
op->src[0]->type == GGML_TYPE_Q3_K ||
false) && (ne11 >= 4 && ne11 <= 8)
)
)
@ -3729,32 +3732,43 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const float sf0 = (float)ne0/op->src[0]->ne[0];
const float sf1 = (float)ne1/op->src[0]->ne[1];
const float sf2 = (float)ne2/op->src[0]->ne[2];
const float sf3 = (float)ne3/op->src[0]->ne[3];
float sf0 = (float)ne0/op->src[0]->ne[0];
float sf1 = (float)ne1/op->src[0]->ne[1];
float sf2 = (float)ne2/op->src[0]->ne[2];
float sf3 = (float)ne3/op->src[0]->ne[3];
const int32_t mode_flags = ggml_get_op_params_i32(op, 0);
float poffs = 0.5f;
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
poffs = 0.0f;
sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
}
ggml_metal_kargs_upscale args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.sf0 =*/ sf0,
/*.sf1 =*/ sf1,
/*.sf2 =*/ sf2,
/*.sf3 =*/ sf3
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.sf0 =*/ sf0,
/*.sf1 =*/ sf1,
/*.sf2 =*/ sf2,
/*.sf3 =*/ sf3,
/*.poffs =*/ poffs,
};
auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);

View File

@ -3481,6 +3481,13 @@ template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4, 4, dequantize_f16_t4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, bfloat4, 4, dequantize_bf16_t4>;
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, bfloat4, 4, dequantize_bf16_t4>;
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, bfloat4, 4, dequantize_bf16_t4>;
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, bfloat4, 4, dequantize_bf16_t4>;
#endif
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>;
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>;
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>;
@ -3531,6 +3538,16 @@ template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q2_K, 256, dequantize_q2_K>;
template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q2_K, 256, dequantize_q2_K>;
template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q2_K, 256, dequantize_q2_K>;
template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q2_K, 256, dequantize_q2_K>;
template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q3_K, 256, dequantize_q3_K>;
template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q3_K, 256, dequantize_q3_K>;
template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q3_K, 256, dequantize_q3_K>;
template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q3_K, 256, dequantize_q3_K>;
template<typename T0, typename T1, short NR0, typename args_t>
void kernel_mul_mv_t_t_impl(
args_t args,
@ -4530,7 +4547,9 @@ kernel void kernel_conv_transpose_2d<half>(
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
kernel void kernel_upscale_f32(
constant bool FC_upscale_aa [[function_constant(FC_UPSCALE + 0)]];
kernel void kernel_upscale_nearest_f32(
constant ggml_metal_kargs_upscale & args,
device const char * src0,
device char * dst,
@ -4556,6 +4575,156 @@ kernel void kernel_upscale_f32(
}
}
// Triangle (tent) filter kernel: 1 - |x|, clamped below at 0.
static inline float bilinear_tri(float x) {
    const float w = 1.0f - fabs(x);
    return w > 0.0f ? w : 0.0f;
}
// Bilinear F32 upscale.
// Dispatch: one threadgroup per output (row i1, channel i2, batch i3); each
// thread strides across the output row (i0) by ntg.x.
// args.poffs is the pixel-center offset: 0.5 normally, 0.0 when
// GGML_SCALE_FLAG_ALIGN_CORNERS is set (see the host-side kargs setup).
// FC_upscale_aa (function constant) selects the antialiased path.
kernel void kernel_upscale_bilinear_f32(
    constant ggml_metal_kargs_upscale & args,
    device const char * src0,
    device char * dst,
    uint3 tgpig[[threadgroup_position_in_grid]],
    uint3 tpitg[[thread_position_in_threadgroup]],
    uint3 ntg[[threads_per_threadgroup]]) {
    const int64_t i3 = tgpig.z;
    const int64_t i2 = tgpig.y;
    const int64_t i1 = tgpig.x;

    // outer dims (dims 2 and 3) use plain nearest mapping
    const int64_t i03 = i3 / args.sf3;
    const int64_t i02 = i2 / args.sf2;

    // source row coordinate, the two rows it falls between (clamped to
    // [0, ne01-1]) and the fractional blend factor between them
    const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
    const int64_t i01 = MAX(0, MIN(args.ne01 - 1, (int64_t)floor(f01)));
    const int64_t i01p = MAX(0, MIN(args.ne01 - 1, i01 + 1));
    const float fd1 = MAX(0.0f, MIN(1.0f, f01 - (float)i01));

    src0 += i03*args.nb03 + i02*args.nb02;
    device float * dst_ptr = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);

    if (FC_upscale_aa) {
        // antialiased path: triangle filter whose support grows to 1/sf when
        // downscaling (sf < 1), normalized by the accumulated weight
        const float support0 = MAX(1.0f, 1.0f / args.sf0);
        const float invscale0 = 1.0f / support0;
        const float support1 = MAX(1.0f, 1.0f / args.sf1);
        const float invscale1 = 1.0f / support1;

        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
            const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;

            // filter footprint in source coordinates, clipped to the tensor
            int64_t x_min = MAX((int64_t)0, (int64_t)floor(f00 - support0 + args.poffs));
            int64_t x_max = MIN(args.ne00, (int64_t)ceil (f00 + support0 + args.poffs));
            int64_t y_min = MAX((int64_t)0, (int64_t)floor(f01 - support1 + args.poffs));
            int64_t y_max = MIN(args.ne01, (int64_t)ceil (f01 + support1 + args.poffs));

            float sum = 0.0f;
            float wsum = 0.0f;

            for (int64_t sy = y_min; sy < y_max; ++sy) {
                const float wy = MAX(0.0f, 1.0f - fabs((float)sy - f01) * invscale1);
                for (int64_t sx = x_min; sx < x_max; ++sx) {
                    const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0);
                    const float w = wx * wy;
                    const device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
                    sum += (*src_ptr) * w;
                    wsum += w;
                }
            }

            // guard against an empty / zero-weight footprint
            const float v = (wsum > 0.0f) ? (sum / wsum) : 0.0f;
            dst_ptr[i0] = v;
        }
    } else {
        // plain bilinear: blend the clamped 2x2 neighborhood with the
        // fractional offsets fd0/fd1
        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
            const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
            const int64_t i00 = MAX(0, MIN(args.ne00 - 1, (int64_t)floor(f00)));
            const int64_t i00p = MAX(0, MIN(args.ne00 - 1, i00 + 1));
            const float fd0 = MAX(0.0f, MIN(1.0f, f00 - (float)i00));

            device const float * src00 = (device const float *)(src0 + i01*args.nb01 + i00*args.nb00);
            device const float * src10 = (device const float *)(src0 + i01*args.nb01 + i00p*args.nb00);
            device const float * src01 = (device const float *)(src0 + i01p*args.nb01 + i00*args.nb00);
            device const float * src11 = (device const float *)(src0 + i01p*args.nb01 + i00p*args.nb00);

            const float v =
                (*src00) * (1.0f - fd0) * (1.0f - fd1) +
                (*src10) * fd0 * (1.0f - fd1) +
                (*src01) * (1.0f - fd0) * fd1 +
                (*src11) * fd0 * fd1;

            dst_ptr[i0] = v;
        }
    }
}
// Cubic convolution weight for the inner taps (|x| <= 1), a = -0.75:
// (a+2)|x|^3 - (a+3)|x|^2 + 1, evaluated in Horner form.
static inline float bicubic_weight1(float x) {
    const float a = -0.75f;
    const float horner = (a + 2) * x - (a + 3);
    return horner * x * x + 1;
}
// Cubic convolution weight for the outer taps (1 < |x| <= 2), a = -0.75:
// a|x|^3 - 5a|x|^2 + 8a|x| - 4a, evaluated step-by-step in Horner form.
static inline float bicubic_weight2(float x) {
    const float a = -0.75f;
    float acc = a * x - 5 * a; // a*x - 5a
    acc = acc * x + 8 * a;     // (...)*x + 8a
    return acc * x - 4 * a;    // (...)*x - 4a
}
// Bicubic F32 upscale using cubic-convolution weights with a = -0.75
// (see bicubic_weight1/bicubic_weight2).
// Dispatch matches the bilinear kernel: one threadgroup per output
// (row i1, channel i2, batch i3); threads stride across the row by ntg.x.
kernel void kernel_upscale_bicubic_f32(
    constant ggml_metal_kargs_upscale & args,
    device const char * src0,
    device char * dst,
    uint3 tgpig[[threadgroup_position_in_grid]],
    uint3 tpitg[[thread_position_in_threadgroup]],
    uint3 ntg[[threads_per_threadgroup]]) {
    const int64_t i3 = tgpig.z;
    const int64_t i2 = tgpig.y;
    const int64_t i1 = tgpig.x;

    // outer dims (dims 2 and 3) use plain nearest mapping
    const int64_t i03 = i3 / args.sf3;
    const int64_t i02 = i2 / args.sf2;

    // vertical sample position and the 4 row weights (taps at dy = -1..2)
    const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
    const int64_t i01 = (int64_t)floor(f01);
    const float fd1 = f01 - (float)i01;

    const float w_y0 = bicubic_weight2(fd1 + 1.0f);
    const float w_y1 = bicubic_weight1(fd1);
    const float w_y2 = bicubic_weight1(1.0f - fd1);
    const float w_y3 = bicubic_weight2(2.0f - fd1);

    const device const char * src_slice = src0 + i03 * args.nb03 + i02 * args.nb02;
    device float * dst_ptr = (device float *)(dst + i3 * args.nb3 + i2 * args.nb2 + i1 * args.nb1);

    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
        // horizontal sample position and the 4 column weights (taps at dx = -1..2)
        const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
        const int64_t i00 = (int64_t)floor(f00);
        const float fd0 = f00 - (float)i00;

        const float w_x0 = bicubic_weight2(fd0 + 1.0f);
        const float w_x1 = bicubic_weight1(fd0);
        const float w_x2 = bicubic_weight1(1.0f - fd0);
        const float w_x3 = bicubic_weight2(2.0f - fd0);

        // separable 4x4 convolution; source indices are clamped at the borders
        float sum = 0.0f;
        for (int dy = -1; dy <= 2; ++dy) {
            const int64_t iy = MAX(0, MIN(args.ne01 - 1, i01 + dy));
            const float wy = (dy == -1) ? w_y0 : (dy == 0) ? w_y1 : (dy == 1) ? w_y2 : w_y3;
            for (int dx = -1; dx <= 2; ++dx) {
                const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx));
                const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3;
                const device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
                sum += (*src_ptr) * wx * wy;
            }
        }
        dst_ptr[i0] = sum;
    }
}
kernel void kernel_pad_f32(
constant ggml_metal_kargs_pad & args,
device const char * src0,

View File

@ -874,4 +874,95 @@ static bool fast_fp16_available(const int cc) {
return true; //Intel GPUs always support FP16.
}
// Reduction operator selector for block_reduce() / block_reduce_policy.
enum class block_reduce_method {
    MAX,
    SUM,
};

// Per-(operator, element type, warp size) policy; specialized below for SUM and MAX.
template<block_reduce_method method_t, typename T, int warp_size>
struct block_reduce_policy;

// True when T is one of Ts... .
template <typename T, typename... Ts>
inline constexpr bool is_any = (std::is_same_v<T, Ts> || ...);

// Always-false, type-dependent helper so a static_assert in an
// uninstantiated if-constexpr branch is well-formed.
template<typename...>
inline constexpr bool ggml_sycl_dependent_false_v = false;

#define WARP_32_SIZE 32
template <typename T, int warp_size> struct block_reduce_policy<block_reduce_method::SUM, T, warp_size> {
static T reduce(T val) {
if constexpr (is_any<T, float, sycl::float2, sycl::half2, int>) {
return warp_reduce_sum<warp_size>(val);
} else {
static_assert(ggml_sycl_dependent_false_v<T>, "Unsupported type for block reduce sum");
}
}
static T sentinel() {
if constexpr (std::is_same_v<T, float>) {
return 0.0f;
} else if constexpr (std::is_same_v<T, sycl::float2>) {
return sycl::float2(0.0f, 0.0f);
} else if constexpr (std::is_same_v<T, sycl::half2>) {
return sycl::half2(0.0f, 0.0f);
} else if constexpr (std::is_same_v<T, int>) {
return 0;
} else {
static_assert(ggml_sycl_dependent_false_v<T>, "Unsupported type for block reduce sum");
}
}
};
template <typename T, int warp_size> struct block_reduce_policy<block_reduce_method::MAX, T, warp_size> {
static T reduce(T val) {
if constexpr (is_any<T, float, sycl::half2>) {
return warp_reduce_max<warp_size>(val);
} else {
static_assert(ggml_sycl_dependent_false_v<T>, "Unsupported type for block reduce max");
}
}
static T sentinel() {
if constexpr (std::is_same_v<T, float>) {
return -INFINITY;
} else if constexpr (std::is_same_v<T, sycl::half2>) {
return sycl::half2(-INFINITY, -INFINITY);
} else {
static_assert(ggml_sycl_dependent_false_v<T>, "Unsupported type for block reduce max");
}
}
};
// Block-wide reduction of one value per work-item.
// Each warp first reduces internally via the policy's warp reduce; warp
// leaders (lane 0) stage their partial result in shared_vals, and the first
// warp then combines the staged partials. block_size_template == 0 means
// "use the runtime local range" for the block size.
template <block_reduce_method reduce_method_t, int warp_size, typename T>
static T block_reduce(T val, T * shared_vals, int block_size_template) {
    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
    val = block_reduce_policy<reduce_method_t, T,warp_size>::reduce(val);
    const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;
    const int nthreads = item_ct1.get_local_range(2);
    const int nwarps = nthreads / WARP_SIZE;
    // single-warp blocks are fully reduced already and fall through to the return
    if (block_size > warp_size) {
        assert((block_size <= 1024) && (block_size % warp_size) == 0);
        const int warp_id = item_ct1.get_local_id(2) / warp_size;
        const int lane_id = item_ct1.get_local_id(2) % warp_size;
        if (lane_id == 0) {
            shared_vals[warp_id] = val;
        }
        item_ct1.barrier(sycl::access::fence_space::local_space);
        // NOTE(review): the cross-warp combine below accumulates with float
        // and operator+= regardless of reduce_method_t and T — this looks
        // wrong for MAX reductions and for non-scalar T (e.g. sycl::half2),
        // and it mixes the WARP_SIZE macro with the warp_size template
        // parameter. Confirm against the call sites before relying on the
        // MAX path for block sizes larger than one warp.
        size_t nreduce = nwarps / WARP_SIZE;
        float tmp = 0.f;
        if (lane_id < (static_cast<int>(block_size) / warp_size)) {
            for (size_t i = 0; i < nreduce; i += 1)
            {
                tmp += shared_vals[lane_id + i * WARP_SIZE];
            }
        }
        return block_reduce_policy<reduce_method_t, T, warp_size>::reduce(tmp);
    }
    return val;
}
#endif // GGML_SYCL_COMMON_HPP

View File

@ -39,6 +39,11 @@ template<typename dst_t, typename src_t>
return sycl::ext::oneapi::bfloat16(float(x));
} else if constexpr (std::is_same_v<src_t, sycl::ext::oneapi::bfloat16>) {
return static_cast<float>(x);
} else if constexpr (std::is_same_v<src_t, sycl::float2> && std::is_same_v<dst_t, sycl::half2>) {
return x.template convert<sycl::half, sycl::rounding_mode::rte>();
} else if constexpr (std::is_same_v<src_t, sycl::float2> &&
std::is_same_v<dst_t, sycl::vec<sycl::ext::oneapi::bfloat16, 2>>) {
return {x.x, x.y};
} else if constexpr(std::is_same_v<dst_t, int32_t>) {
return int32_t(x);
} else {
@ -46,4 +51,5 @@ template<typename dst_t, typename src_t>
}
}
#endif // GGML_SYCL_CONVERT_HPP

View File

@ -9,23 +9,32 @@
#define SYCL_LOCAL_ID_CALC(ITEM, IDX) \
(ITEM.get_local_range(IDX) * ITEM.get_group(IDX) + ITEM.get_local_id(IDX))
static void acc_f32(const float * x, const float * y, float * dst, const int64_t ne,
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
const int64_t s11, const int64_t s12, const int64_t s13, const int64_t offset) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
const int64_t i = SYCL_LOCAL_ID_CALC(item_ct1, 2);
static void acc_f32(const float * x, const float * y, float * dst, const int ne,
const int ne10, const int ne11, const int ne12,
const int nb1, const int nb2, int offset, const sycl::nd_item<1> &item_ct1) {
const int i = SYCL_LOCAL_ID_CALC(item_ct1, 0);
if (i >= ne) {
return;
}
int src1_idx = i - offset;
int oz = src1_idx / nb2;
int oy = (src1_idx - (oz * nb2)) / nb1;
int ox = src1_idx % nb1;
if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
} else {
dst[i] = x[i];
int64_t src1_idx = i - offset;
int64_t tmp = src1_idx;
const int64_t i13 = tmp / s13;
tmp -= i13 * s13;
const int64_t i12 = tmp / s12;
tmp -= i12 * s12;
const int64_t i11 = tmp / s11;
tmp -= i11 * s11;
const int64_t i10 = tmp;
float val = x[i];
if (src1_idx >= 0 && i10 < ne10 && i11 < ne11 && i12 < ne12 && i13 < ne13) {
val += y[((i13*ne12 + i12) * ne11 + i11) * ne10 + i10];
}
dst[i] = val;
}
/* Unary OP funcs */
@ -364,18 +373,15 @@ static void gated_op_fused_geglu_quick(const T * x, const T * g, T * dst, const
namespace ggml_sycl_detail {
static void acc_f32_sycl(const float *x, const float *y, float *dst,
const int n_elements, const int ne10, const int ne11,
const int ne12, const int nb1, const int nb2,
const int offset, queue_ptr stream) {
int num_blocks = ceil_div(n_elements, SYCL_ACC_BLOCK_SIZE);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) *
sycl::range<1>(SYCL_ACC_BLOCK_SIZE),
sycl::range<1>(SYCL_ACC_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset,
item_ct1);
});
const int64_t n_elements, const int64_t ne10, const int64_t ne11,
const int64_t ne12, const int64_t ne13, const int64_t s1, const int64_t s2, const int64_t s3,
const int64_t offset, queue_ptr stream) {
const int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset);
});
}
template<typename T>
@ -402,25 +408,19 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
template<typename KernelInvoker, typename... Args>
static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
#if defined (GGML_SYCL_F16)
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
#else
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
#endif
GGML_ASSERT(dst->src[0]->type == dst->type);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
switch (dst->type) {
#if defined (GGML_SYCL_F16)
case GGML_TYPE_F16:
{
auto data_pts = cast_data<sycl::half>(dst);
kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
break;
}
#endif
case GGML_TYPE_F32:
{
auto data_pts = cast_data<float>(dst);
@ -434,14 +434,10 @@ static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx,
template<typename KernelInvoker, typename... Args>
static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
#if defined (GGML_SYCL_F16)
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
#else
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
#endif
GGML_ASSERT(dst->src[0]->type == dst->type);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const ggml_tensor * src0 = dst->src[0];
@ -463,7 +459,6 @@ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & c
GGML_ASSERT(src0->type == src1->type);
}
switch (dst->type) {
#if defined (GGML_SYCL_F16)
case GGML_TYPE_F16:
{
sycl::half * src0_p = (sycl::half *) src0_d;
@ -484,7 +479,6 @@ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & c
std::forward<Args>(args)...);
break;
}
#endif
case GGML_TYPE_F32:
{
float * src0_p = (float *) src0_d;
@ -513,13 +507,9 @@ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & c
template<typename KernelInvoker, typename... Args>
static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
#if defined (GGML_SYCL_F16)
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
#else
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
#endif
GGML_ASSERT(dst->src[0]->type == dst->type);
dpct::queue_ptr main_stream = ctx.stream();
@ -530,7 +520,6 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2];
const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3];
switch (dst->type) {
#if defined (GGML_SYCL_F16)
case GGML_TYPE_F16:
{
auto data_pts = cast_data<sycl::half>(dst);
@ -539,7 +528,6 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
main_stream, std::forward<Args>(args)...);
break;
}
#endif
case GGML_TYPE_F32:
{
auto data_pts = cast_data<float>(dst);
@ -868,22 +856,31 @@ static inline void ggml_sycl_op_trunc(ggml_backend_sycl_context & ctx, ggml_tens
}
static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
float * dst_d = (float *) dst->data;
dpct::queue_ptr stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
const float * src1_dd = static_cast<const float*>(dst->src[1]->data);
float * dst_dd = static_cast<float *>(dst->data);
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
// int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
int offset = dst->op_params[3] / 4; // offset in bytes
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(dst->nb[0] == ggml_element_size(dst));
GGML_ASSERT(ggml_is_contiguously_allocated(dst));
ggml_sycl_detail::acc_f32_sycl(src0_dd, src1_dd, dst_dd, (int)ggml_nelements(dst), (int)dst->src[1]->ne[0], (int)dst->src[1]->ne[1], (int)dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
const int64_t s1 = dst->op_params[0] / sizeof(float);
const int64_t s2 = dst->op_params[1] / sizeof(float);
const int64_t s3 = dst->op_params[2] / sizeof(float);
const int64_t offset = dst->op_params[3] / sizeof(float);
ggml_sycl_detail::acc_f32_sycl(src0_d, src1_d, dst_d, ggml_nelements(dst),
src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
s1, s2, s3, offset, stream);
}
static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {

View File

@ -4145,6 +4145,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_ROPE:
ggml_sycl_rope(ctx, dst);
break;
case GGML_OP_ROPE_BACK:
ggml_sycl_rope_back(ctx, dst);
break;
case GGML_OP_IM2COL:
ggml_sycl_im2col(ctx, dst);
break;
@ -4851,6 +4854,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return max_bias == 0.0f;
}
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
case GGML_OP_IM2COL:
return true;
case GGML_OP_UPSCALE:
@ -4872,8 +4876,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
k > 0 && k <= 32;
}
case GGML_OP_POOL_2D:
case GGML_OP_ACC:
return true;
case GGML_OP_ACC:
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
case GGML_OP_PAD:
// TODO: add circular padding support for syscl, see https://github.com/ggml-org/llama.cpp/pull/16985
if (ggml_get_op_params_i32(op, 8) != 0) {

View File

@ -202,47 +202,34 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
}
}
static void l2_norm_f32(const float* x, float* dst, const int ncols, const float eps,
const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
item_ct1.get_local_id(1);
const int tid = item_ct1.get_local_id(2);
const int nthreads = item_ct1.get_local_range(2);
const int nwarps = nthreads / WARP_SIZE;
template<int warp_size>
static void l2_norm_f32(const float * x, float * dst, const int ncols,
const int64_t stride_row, const int64_t stride_channel,
const int64_t stride_sample, const float eps,
const sycl::nd_item<3>& item_ct1, float* s_sum, const int block_size) {
const int nrows = item_ct1.get_group_range(2);
const int nchannels = item_ct1.get_group_range(1);
const int row = item_ct1.get_group(2);
const int channel = item_ct1.get_group(1);
const int sample = item_ct1.get_group(0);
const int tid = item_ct1.get_local_id(2);
x += sample*stride_sample + channel*stride_channel + row*stride_row;
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[row * ncols + col];
const float xi = x[col];
tmp += xi * xi;
}
// sum up partial sums
tmp = warp_reduce_sum(tmp, item_ct1);
if (block_size > WARP_SIZE) {
int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
/*
DPCT1118:3: SYCL group functions and algorithms must be encountered in
converged control flow. You may need to adjust the code.
*/
item_ct1.barrier(sycl::access::fence_space::local_space);
size_t nreduce = nwarps / WARP_SIZE;
tmp = 0.f;
for (size_t i = 0; i < nreduce; i += 1)
{
tmp += s_sum[lane_id + i * WARP_SIZE];
}
tmp = warp_reduce_sum(tmp, item_ct1);
}
const float scale = sycl::rsqrt(sycl::max(tmp, eps * eps));
tmp = block_reduce<block_reduce_method::SUM, warp_size>(tmp, s_sum, block_size);
const float scale = sycl::rsqrt(sycl::fmax(tmp, eps * eps));
for (int col = tid; col < ncols; col += block_size) {
dst[row * ncols + col] = scale * x[row * ncols + col];
dst[col] = scale * x[col];
}
}
@ -369,42 +356,50 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
}
}
static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
const int nrows, const float eps,
queue_ptr stream, int device) {
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
template<int warp_size>
static void l2_norm_f32_sycl(const float * x,
float * dst,
const int ncols,
const int nrows,
const int nchannels,
const int nsamples,
const int64_t stride_row,
const int64_t stride_channel,
const int64_t stride_sample,
const float eps,
queue_ptr stream,
int device) {
const dpct::dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
const dpct::dim3 block_dims(warp_size, 1, 1);
stream->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
sycl::nd_range<3>(blocks_num * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
l2_norm_f32(x, dst, ncols, eps, item_ct1,
nullptr, WARP_SIZE);
[[sycl::reqd_sub_group_size(warp_size)]] {
l2_norm_f32<warp_size>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
nullptr, warp_size);
});
});
}
else {
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
assert(work_group_size % (warp_size * warp_size) == 0);
const sycl::range<3> block_dims(1, 1, work_group_size);
/*
DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed.
*/
int lsm_size = block_dims[2] > warp_size ? work_group_size / warp_size * sizeof(float): 0;
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(lsm_size),
cgh);
cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
sycl::nd_range<3>(blocks_num * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
l2_norm_f32(x, dst, ncols, eps, item_ct1,
get_pointer(s_sum_acc_ct1), work_group_size);
[[sycl::reqd_sub_group_size(warp_size)]] {
l2_norm_f32<warp_size>(x, dst, ncols, stride_row, stride_channel, stride_sample,
eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
});
});
}
@ -634,21 +629,28 @@ void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * d
}
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
dpct::queue_ptr stream = ctx.stream();
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const int64_t ne00 = dst->src[0]->ne[0];
const int64_t nrows = ggml_nrows(dst->src[0]);
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
GGML_TENSOR_UNARY_OP_LOCALS;
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps >= 0.0f);
l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
const size_t ts0 = ggml_type_size(src0->type);
GGML_ASSERT(nb00 == ts0);
const int64_t s01 = nb01 / ts0;
const int64_t s02 = nb02 / ts0;
const int64_t s03 = nb03 / ts0;
/*support both WARP_SIZE or WARP_32_SIZE in code
choose by hardware for better performance
*/
l2_norm_f32_sycl<WARP_SIZE>(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream, ctx.device);
}

View File

@ -1,4 +1,5 @@
#include "rope.hpp"
#include "convert.hpp"
#include "ggml-sycl/common.hpp"
#include "ggml.h"
@ -15,366 +16,489 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y));
}
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
static void rope_yarn(
float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
float * cos_theta, float * sin_theta) {
// Get n-d rotational scaling corrected for extrapolation
template <bool forward>
static void rope_yarn(const float theta_extrap, const float freq_scale,
const rope_corr_dims corr_dims, const int64_t i0,
const float ext_factor, float mscale, float &cos_theta,
float &sin_theta) {
float theta_interp = freq_scale * theta_extrap;
float theta = theta_interp;
if (ext_factor != 0.0f) {
float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
float ramp_mix =
rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * sycl::log(1.0f / freq_scale);
}
*cos_theta = sycl::cos(theta) * mscale;
*sin_theta = sycl::sin(theta) * mscale;
cos_theta = sycl::cos(theta) * mscale;
sin_theta = sycl::sin(theta) * mscale;
if (!forward) {
sin_theta *= -1.0f;
}
}
template <typename T, bool has_ff>
static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
const int32_t * pos, float freq_scale, float ext_factor, float attn_factor,
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
const sycl::nd_item<3> & item_ct1) {
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
template <bool forward, bool has_ff, typename T, typename D>
static void rope_norm(const T *x, D *dst, const int ne00, const int ne01,
const int ne02, const int s01, const int s02,
const int s03, const int s1, const int s2, const int s3,
const int n_dims, const int32_t *pos,
const float freq_scale, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims,
const float theta_scale, const float *freq_factors,
const int64_t *row_indices, const int set_rows_stride) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
item_ct1.get_local_id(1));
if (i0 >= ne0) {
if (i0 >= ne00) {
return;
}
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
const int row0 = row % ne1;
const int channel0 = row / ne1;
const uint32_t i3 = row_dst / (ne01 * ne02);
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
const int i = row * ne0 + i0;
const int i2 = channel0 * s2 + row0 * s1 + i0;
int idst = i0 + i1 * s1 + i2 * s2 + i3 * s3;
const int ix = i0 + i1 * s01 + i2 * s02 + i3 * s03;
if (set_rows_stride != 0) {
idst = i1 * s1 + i0;
idst += row_indices[i2] * set_rows_stride;
}
const auto &store_coaelsced = [&](float x0, float x1) {
if constexpr (std::is_same_v<float, D>) {
sycl::float2 v = sycl::float2(x0, x1);
ggml_sycl_memcpy_1<8>(dst + idst, &v);
} else if constexpr (std::is_same_v<sycl::half, D>) {
sycl::half2 v = sycl::half2(x0, x1);
ggml_sycl_memcpy_1<4>(dst + idst, &v);
}
};
if (i0 >= n_dims) {
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2);
store_coaelsced(x[ix + 0], x[ix + 1]);
return;
}
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
const float theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
float cos_theta;
float sin_theta;
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
ext_factor, attn_factor, cos_theta, sin_theta);
const float x0 = x[i2 + 0];
const float x1 = x[i2 + 1];
const float x0 = x[ix + 0];
const float x1 = x[ix + 1];
dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
dst[i + 1] = x0 * sin_theta + x1 * cos_theta;
store_coaelsced(x0 * cos_theta - x1 * sin_theta,
x0 * sin_theta + x1 * cos_theta);
}
template <typename T, bool has_ff>
static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
const sycl::nd_item<3> & item_ct1) {
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
template <bool forward, bool has_ff, typename T, typename D>
static void rope_neox(const T *x, D *dst, const int ne00, const int ne01,
const int ne02, const int s01, const int s02,
const int s03, const int s1, const int s2, const int s3,
const int n_dims, const int32_t *pos,
const float freq_scale, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims,
const float theta_scale, const float *freq_factors,
const int64_t *row_indices, const int set_rows_stride) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
item_ct1.get_local_id(1));
if (i0 >= ne0) {
if (i0 >= ne00) {
return;
}
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
const int row0 = row % ne1;
const int channel0 = row / ne1;
const uint32_t i3 = row_dst / (ne01 * ne02);
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
const int i = row * ne0 + i0 / 2;
const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
if (set_rows_stride != 0) {
idst = i1 * s1 + i0 / 2;
idst += row_indices[i2] * set_rows_stride;
}
if (i0 >= n_dims) {
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2 + i0 / 2);
dst[idst + i0 / 2 + 0] = ggml_sycl_cast<D>(x[ix + i0 / 2 + 0]);
dst[idst + i0 / 2 + 1] = ggml_sycl_cast<D>(x[ix + i0 / 2 + 1]);
return;
}
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
const float theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
float cos_theta;
float sin_theta;
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
ext_factor, attn_factor, cos_theta, sin_theta);
const float x0 = x[i2 + 0];
const float x1 = x[i2 + n_dims / 2];
const float x0 = x[ix + 0];
const float x1 = x[ix + n_dims / 2];
dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
dst[idst + 0] = ggml_sycl_cast<D>(x0 * cos_theta - x1 * sin_theta);
dst[idst + n_dims / 2] = ggml_sycl_cast<D>(x0 * sin_theta + x1 * cos_theta);
}
template <typename T, bool has_ff>
static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
const float theta_scale, const float * freq_factors, const mrope_sections sections,
const bool is_imrope, const sycl::nd_item<3> & item_ct1) {
// get index pos
const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
if (i0 >= ne0) {
template <bool forward, bool has_ff, typename T>
static void rope_multi(const T *x, T *dst, const int ne00, const int ne01,
const int ne02, const int s01, const int s02,
const int s03, const int s1, const int s2, const int s3,
const int n_dims, const int32_t *pos,
const float freq_scale, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims,
const float theta_scale, const float *freq_factors,
const mrope_sections sections, const bool is_imrope) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
item_ct1.get_local_id(1));
if (i0 >= ne00) {
return;
}
const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
const int idst = (row_dst * ne0) + (i0 / 2);
const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
const uint32_t i3 = row_dst / (ne01 * ne02);
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
if (i0 >= n_dims) {
*reinterpret_cast<sycl::vec<T, 2> *>(dst + idst + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + ix);
dst[idst + i0 / 2 + 0] = x[ix + i0 / 2 + 0];
dst[idst + i0 / 2 + 1] = x[ix + i0 / 2 + 1];
return;
}
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
const int sect_dims =
sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
const int sec_w = sections.v[1] + sections.v[0];
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0;
if (is_imrope) {
if (sector % 3 == 1 && sector < 3 * sections.v[1]) {
theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) {
theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) {
theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
theta_base = pos[i2 + ne02 * 1] * dpct::pow(theta_scale, i0 / 2.0f);
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
theta_base = pos[i2 + ne02 * 2] * dpct::pow(theta_scale, i0 / 2.0f);
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
} else {
theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
theta_base = pos[i2 + ne02 * 3] * dpct::pow(theta_scale, i0 / 2.0f);
}
} else {
if (sector < sections.v[0]) {
theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
}
else if (sector >= sections.v[0] && sector < sec_w) {
theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
}
else if (sector >= sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
} else if (sector >= sections.v[0] && sector < sec_w) {
theta_base = pos[i2 + ne02 * 1] * dpct::pow(theta_scale, i0 / 2.0f);
} else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
theta_base = pos[i2 + ne02 * 2] * dpct::pow(theta_scale, i0 / 2.0f);
} else if (sector >= sec_w + sections.v[2]) {
theta_base = pos[i2 + ne02 * 3] * dpct::pow(theta_scale, i0 / 2.0f);
}
}
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
float cos_theta;
float sin_theta;
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
const float x0 = x[ix + 0];
const float x1 = x[ix + n_dims/2];
// store results in dst
dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
dst[idst + n_dims/2] = x0 * sin_theta + x1 * cos_theta;
float cos_theta;
float sin_theta;
rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
ext_factor, attn_factor, cos_theta, sin_theta);
const float x0 = x[ix + 0];
const float x1 = x[ix + n_dims / 2];
dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
dst[idst + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
}
template <bool forward, bool has_ff, typename T>
static void rope_vision(const T *x, T *dst, const int ne00, const int ne01,
const int ne02, const int s01, const int s02,
const int s03, const int s1, const int s2, const int s3,
const int n_dims, const int32_t *pos,
const float freq_scale, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims,
const float theta_scale, const float *freq_factors,
const mrope_sections sections) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
item_ct1.get_local_id(1));
template <typename T, bool has_ff>
static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
const float theta_scale, const float * freq_factors, const mrope_sections sections,
const sycl::nd_item<3> & item_ct1) {
// get index pos
const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
if (i0 >= ne0) {
if (i0 >= ne00) {
return;
}
const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
const int idst = (row_dst * ne0) + (i0 / 2);
const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
const uint32_t i3 = row_dst / (ne01 * ne02);
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
const int sect_dims = sections.v[0] + sections.v[1];
const int sector = (i0 / 2) % sect_dims;
const int sec_w = sections.v[1] + sections.v[0];
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0f;
float theta_base = 0.0;
if (sector < sections.v[0]) {
const int p = sector;
theta_base = pos[channel_x] * sycl::pow(theta_scale, (float) p);
} else {
theta_base = pos[i2] * dpct::pow(theta_scale, p);
} else if (sector >= sections.v[0] && sector < sec_w) {
const int p = sector - sections.v[0];
theta_base = pos[channel_x + ne2] * sycl::pow(theta_scale, (float) p);
theta_base = pos[i2 + ne02] * dpct::pow(theta_scale, p);
}
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
float cos_theta;
float sin_theta;
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
float cos_theta;
float sin_theta;
rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
ext_factor, attn_factor, cos_theta, sin_theta);
const float x0 = x[ix + 0];
const float x1 = x[ix + n_dims];
// store results in dst
dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
dst[idst + n_dims] = x0 * sin_theta + x1 * cos_theta;
}
template <typename T>
static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
const int n_dims, int nr, const int32_t * pos, const float freq_scale, const float freq_base,
const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
const float * freq_factors, queue_ptr stream) {
GGML_ASSERT(ne0 % 2 == 0);
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
const sycl::range<3> block_nums(1, num_blocks_x, nr);
template <bool forward, typename T, typename D>
static void
rope_norm_sycl(const T *x, D *dst, const int ne00, const int ne01,
const int ne02, const int s01, const int s02, const int s03,
const int s1, const int s2, const int s3, const int n_dims,
const int nr, const int32_t *pos, const float freq_scale,
const float freq_base, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims,
const float *freq_factors, const int64_t *row_indices,
const int set_rows_stride, dpct::queue_ptr stream) {
GGML_ASSERT(ne00 % 2 == 0);
const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x =
(ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
const dpct::dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f / n_dims);
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
if (freq_factors == nullptr) {
/*
DPCT1049:40: The work-group size passed to the SYCL kernel may exceed
the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed.
*/
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, item_ct1);
});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
GGML_UNUSED(item_ct1);
rope_norm<forward, false>(
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, row_indices, set_rows_stride);
});
} else {
/*
DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed.
*/
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, item_ct1);
});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
GGML_UNUSED(item_ct1);
rope_norm<forward, true>(
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, row_indices, set_rows_stride);
});
}
}
template <typename T>
static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
const int n_dims, const int nr, const int32_t * pos, const float freq_scale,
const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
GGML_ASSERT(ne0 % 2 == 0);
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
const sycl::range<3> block_nums(1, num_blocks_x, nr);
template <bool forward, typename T, typename D>
static void
rope_neox_sycl(const T *x, D *dst, const int ne00, const int ne01,
const int ne02, const int s01, const int s02, const int s03,
const int s1, const int s2, const int s3, const int n_dims,
const int nr, const int32_t *pos, const float freq_scale,
const float freq_base, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims,
const float *freq_factors, const int64_t *row_indices,
const int set_rows_stride, dpct::queue_ptr stream) {
GGML_ASSERT(ne00 % 2 == 0);
const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x =
(ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
const dpct::dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f / n_dims);
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
if (freq_factors == nullptr) {
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, item_ct1);
});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
GGML_UNUSED(item_ct1);
rope_neox<forward, false>(
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, row_indices, set_rows_stride);
});
} else {
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, item_ct1);
});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
GGML_UNUSED(item_ct1);
rope_neox<forward, true>(
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, row_indices, set_rows_stride);
});
}
}
template <typename T>
static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
const size_t s2, const int n_dims, const int nr, const int32_t * pos,
const float freq_scale, const float freq_base, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
const mrope_sections sections, const bool is_imrope, queue_ptr stream) {
GGML_ASSERT(ne0 % 2 == 0);
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
const sycl::range<3> grid_dims(1, n_blocks_y, nr);
const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
template <bool forward, typename T>
static void
rope_multi_sycl(const T *x, T *dst, const int ne00, const int ne01,
const int ne02, const int s01, const int s02, const int s03,
const int s1, const int s2, const int s3, const int n_dims,
const int nr, const int32_t *pos, const float freq_scale,
const float freq_base, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims,
const float *freq_factors, const mrope_sections sections,
const bool is_imrope, dpct::queue_ptr stream) {
GGML_ASSERT(ne00 % 2 == 0);
const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x =
(ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
const dpct::dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f / n_dims);
const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
// Add FP16 capability check if T could be sycl::half
if constexpr (std::is_same_v<T, sycl::half>) {
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
}
// launch kernel
if (freq_factors == nullptr) {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
GGML_UNUSED(item_ct1);
rope_multi<forward, false, T>(
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, sections, is_imrope);
});
} else {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
GGML_UNUSED(item_ct1);
rope_multi<forward, true, T>(
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, sections, is_imrope);
});
}
}
template <bool forward, typename T>
static void
rope_vision_sycl(const T *x, T *dst, const int ne00, const int ne01,
const int ne02, const int s01, const int s02, const int s03,
const int s1, const int s2, const int s3, const int n_dims,
const int nr, const int32_t *pos, const float freq_scale,
const float freq_base, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims,
const float *freq_factors, const mrope_sections sections,
dpct::queue_ptr stream) {
GGML_ASSERT(ne00 % 2 == 0);
const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x =
(ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
const dpct::dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f / n_dims);
// rope vision
template <typename T>
static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
const size_t s2, const int n_dims, const int nr, const int32_t * pos,
const float freq_scale, const float freq_base, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
const mrope_sections sections, queue_ptr stream) {
GGML_ASSERT(ne0 % 2 == 0);
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
const sycl::range<3> grid_dims(1, n_blocks_y, nr);
const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
// Add FP16 capability check if T could be sycl::half
if constexpr (std::is_same_v<T, sycl::half>) {
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
}
// launch kernel
if (freq_factors == nullptr) {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_vision<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, item_ct1);
});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
GGML_UNUSED(item_ct1);
rope_vision<forward, false, T>(
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, sections);
});
} else {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_vision<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, item_ct1);
});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
GGML_UNUSED(item_ct1);
rope_vision<forward, true, T>(
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, sections);
});
}
}
inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
template <bool forward>
void ggml_sycl_op_rope_impl(ggml_backend_sycl_context &ctx, ggml_tensor *dst,
const ggml_tensor *set_rows = nullptr) {
const ggml_tensor *src0 = dst->src[0];
const ggml_tensor *src1 = dst->src[1];
const ggml_tensor *src2 = dst->src[2];
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(dst->src[0]->type == dst->type);
const int64_t ne00 = dst->src[0]->ne[0]; // head dims
const int64_t ne01 = dst->src[0]->ne[1]; // num heads
const int64_t ne02 = dst->src[0]->ne[2]; // num heads
const int64_t nr = ggml_nrows(dst->src[0]);
const float *src0_d = (const float *)src0->data;
const float *src1_d = (const float *)src1->data;
const size_t s01 = dst->src[0]->nb[1] / ggml_type_size(dst->src[0]->type);
const size_t s02 = dst->src[0]->nb[2] / ggml_type_size(dst->src[0]->type);
void *dst_d = dst->data;
const int64_t *row_indices = nullptr;
ggml_type dst_type = dst->type;
int set_rows_stride = 0;
if (set_rows != nullptr) {
GGML_ASSERT(forward);
dst_d = set_rows->data;
row_indices = (const int64_t *)set_rows->src[1]->data;
dst_type = set_rows->type;
set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type);
}
dpct::queue_ptr stream = ctx.stream();
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
//const int n_ctx = ((int32_t *) dst->op_params)[3];
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type ||
(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16));
const int64_t ne00 = src0->ne[0]; // head dims
const int64_t ne01 = src0->ne[1]; // num heads
const int64_t ne02 = src0->ne[2]; // num heads
const int64_t nr = ggml_nrows(src0);
const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
const size_t s03 = src0->nb[3] / ggml_type_size(src0->type);
const size_t s1 = dst->nb[1] / ggml_type_size(dst->type);
const size_t s2 = dst->nb[2] / ggml_type_size(dst->type);
const size_t s3 = dst->nb[3] / ggml_type_size(dst->type);
const int n_dims = ((int32_t *)dst->op_params)[1];
const int mode = ((int32_t *)dst->op_params)[2];
const int n_ctx_orig = ((int32_t *)dst->op_params)[4];
mrope_sections sections;
// RoPE alteration for extended context
float freq_base;
float freq_scale;
float ext_factor;
@ -382,13 +506,13 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
float beta_fast;
float beta_slow;
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&sections.v, (int32_t *) dst->op_params + 11, sizeof(int)*4);
memcpy(&freq_base, (int32_t *)dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *)dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *)dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *)dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *)dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *)dst->op_params + 10, sizeof(float));
memcpy(&sections.v, (int32_t *)dst->op_params + 11, sizeof(int) * 4);
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
@ -396,82 +520,122 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_mrope) {
GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 ||
sections.v[2] > 0);
}
if (is_vision) {
GGML_ASSERT(n_dims == ne00/2);
GGML_ASSERT(n_dims == ne00 / 2);
}
const int32_t * pos = (const int32_t *) dst->src[1]->data;
const int32_t *pos = (const int32_t *)src1_d;
const float * freq_factors = nullptr;
if (dst->src[2] != nullptr) {
freq_factors = (const float *) dst->src[2]->data;
const float *freq_factors = nullptr;
if (src2 != nullptr) {
freq_factors = (const float *)src2->data;
}
rope_corr_dims corr_dims;
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast,
beta_slow, corr_dims.v);
// compute
if (is_neox) {
GGML_SYCL_DEBUG("%s: neox path\n", __func__);
if (dst->src[0]->type == GGML_TYPE_F32) {
rope_neox_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
} else if (dst->src[0]->type == GGML_TYPE_F16) {
rope_neox_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
main_stream);
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
rope_neox_sycl<forward, float, float>(
(const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
rope_neox_sycl<forward, float, sycl::half>(
(const float *)src0_d, (sycl::half *)dst_d, ne00, ne01, ne02,
s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
row_indices, set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
rope_neox_sycl<forward, sycl::half, sycl::half>(
(const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
row_indices, set_rows_stride, stream);
} else {
GGML_ABORT("fatal error");
GGML_ABORT("Fatal error: Tensor type unsupported!");
}
} else if (is_mrope && !is_vision) {
GGML_SYCL_DEBUG("%s: mrope path\n", __func__);
if (dst->src[0]->type == GGML_TYPE_F16) {
rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01,
s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, sections, is_imrope, main_stream);
} else if (dst->src[0]->type == GGML_TYPE_F32) {
rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
is_imrope, main_stream);
if (src0->type == GGML_TYPE_F32) {
rope_multi_sycl<forward>((const float *)src0_d, (float *)dst_d,
ne00, ne01, ne02, s01, s02, s03, s1, s2,
s3, n_dims, nr, pos, freq_scale, freq_base,
ext_factor, attn_factor, corr_dims,
freq_factors, sections, is_imrope, stream);
} else if (src0->type == GGML_TYPE_F16) {
rope_multi_sycl<forward>(
(const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
sections, is_imrope, stream);
} else {
GGML_ABORT("Fatal error: Tensor type unsupported!");
}
} else if (is_vision) {
GGML_SYCL_DEBUG("%s: vision path\n", __func__);
if (dst->src[0]->type == GGML_TYPE_F16) {
rope_vision_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, ne02, s01,
s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, sections, main_stream);
} else if (dst->src[0]->type == GGML_TYPE_F32) {
rope_vision_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
main_stream);
if (src0->type == GGML_TYPE_F32) {
rope_vision_sycl<forward>(
(const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
ext_factor, attn_factor, corr_dims, freq_factors, sections,
stream);
} else if (src0->type == GGML_TYPE_F16) {
rope_vision_sycl<forward>(
(const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
sections, stream);
} else {
GGML_ABORT("Fatal error: Tensor type unsupported!");
}
} else {
GGML_SYCL_DEBUG("%s: norm path\n", __func__);
if (dst->src[0]->type == GGML_TYPE_F32) {
rope_norm_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
} else if (dst->src[0]->type == GGML_TYPE_F16) {
rope_norm_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
main_stream);
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
rope_norm_sycl<forward, float, float>(
(const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
rope_norm_sycl<forward, float, sycl::half>(
(const float *)src0_d, (sycl::half *)dst_d, ne00, ne01, ne02,
s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
row_indices, set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
rope_norm_sycl<forward, sycl::half, sycl::half>(
(const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
row_indices, set_rows_stride, stream);
} else {
GGML_ABORT("fatal error");
GGML_ABORT("Fatal error: Tensor type unsupported!");
}
}
}
void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
void ggml_sycl_rope(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
ggml_sycl_op_rope(ctx, dst);
ggml_sycl_op_rope_impl<true>(ctx, dst);
}
void ggml_sycl_rope_back(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
ggml_sycl_op_rope_impl<false>(ctx, dst);
}
void ggml_sycl_rope_fused(ggml_backend_sycl_context &ctx, ggml_tensor *rope,
ggml_tensor *set_rows) {
scope_op_debug_print scope_dbg_print(__func__, rope, /*num_src=*/3);
ggml_sycl_op_rope_impl<true>(ctx, rope, set_rows);
}

View File

@ -15,6 +15,12 @@
#include "common.hpp"
#define SYCL_ROPE_BLOCK_SIZE 256
void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
void ggml_sycl_rope_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_rope_fused(ggml_backend_sycl_context & ctx, ggml_tensor * dst, ggml_tensor * set_rows);
#endif // GGML_SYCL_ROPE_HPP

View File

@ -42,11 +42,20 @@
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
// Matrix-vector multiplication parameters
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
// Must be multiple of 4 to work with vectorized paths, and must divide
// mul_mat_vec wg size
#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
#define WEBGPU_MUL_MAT_VEC_TILE_K 256
#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 64
#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 256
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 64
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 256
// Requires 32 threads per output (wg_size/outputs_per_wg == 32)
#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8
// Requires at least two (and multiple of 2) k-quant blocks per tile
#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 512
// default size for legacy matrix multiplication
#define WEBGPU_MUL_MAT_WG_SIZE 256
@ -199,7 +208,8 @@ struct ggml_webgpu_binary_pipeline_key {
bool src_overlap;
bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap;
return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap &&
src_overlap == other.src_overlap;
}
};
@ -749,6 +759,36 @@ class ggml_webgpu_shader_lib {
std::vector<std::string> defines;
std::string variant = "mul_mat_vec";
// src0 type (matrix row)
switch (context.src0->type) {
case GGML_TYPE_F32:
defines.push_back("SRC0_INNER_TYPE=f32");
defines.push_back("MUL_ACC_FLOAT");
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("SRC0_INNER_TYPE=f16");
defines.push_back("MUL_ACC_FLOAT");
variant += "_f16";
break;
default:
{
// Quantized types: use helpers but accumulate in f16
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
std::string src0_name = src0_traits->type_name;
std::string type_upper = src0_name;
variant += "_" + src0_name;
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
defines.push_back("BYTE_HELPERS");
defines.push_back("MUL_ACC_" + type_upper);
// For fast path we always dequantize from f16 inside the shader
defines.push_back("SRC0_INNER_TYPE=f16");
break;
}
}
// src1 type (vector)
switch (context.src1->type) {
case GGML_TYPE_F32:
@ -763,39 +803,21 @@ class ggml_webgpu_shader_lib {
GGML_ABORT("Unsupported src1 type for mul_mat_vec shader");
}
// src0 type (matrix row)
switch (context.src0->type) {
case GGML_TYPE_F32:
defines.push_back("SRC0_INNER_TYPE=f32");
defines.push_back("MUL_ACC_FLOAT");
break;
case GGML_TYPE_F16:
defines.push_back("SRC0_INNER_TYPE=f16");
defines.push_back("MUL_ACC_FLOAT");
break;
default:
{
// Quantized types: use helpers but accumulate in f16
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
std::string src0_name = src0_traits->type_name;
std::string type_upper = src0_name;
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
defines.push_back("BYTE_HELPERS");
defines.push_back("MUL_ACC_" + type_upper);
// For fast path we always dequantize from f16 inside the shader
defines.push_back("SRC0_INNER_TYPE=f16");
break;
}
}
// VEC/SCALAR controls
defines.push_back(key.vectorized ? "VEC" : "SCALAR");
uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
uint32_t tile_k = WEBGPU_MUL_MAT_VEC_TILE_K;
uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
uint32_t tile_k = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K;
uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;
if (key.src0_type >= GGML_TYPE_Q2_K) {
tile_k = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K;
outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
} else if (key.src0_type >= GGML_TYPE_Q4_0) {
tile_k = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K;
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
defines.push_back(std::string("TILE_K=") + std::to_string(tile_k));
defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
@ -1061,10 +1083,10 @@ class ggml_webgpu_shader_lib {
webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_binary_pipeline_key key = {
.type = context.dst->type,
.op = context.dst->op,
.inplace = context.inplace,
.overlap = context.overlap,
.type = context.dst->type,
.op = context.dst->op,
.inplace = context.inplace,
.overlap = context.overlap,
.src_overlap = context.src_overlap,
};

View File

@ -8,7 +8,6 @@
#include "ggml-backend-impl.h"
#include "ggml-impl.h"
#include "ggml-webgpu-shader-lib.hpp"
#include "pre_wgsl.hpp"
#ifdef __EMSCRIPTEN__
# include <emscripten/emscripten.h>
@ -20,12 +19,18 @@
#include <condition_variable>
#include <cstdint>
#include <cstring>
#include <iostream>
#ifdef GGML_WEBGPU_GPU_PROFILE
# include <iomanip>
#endif
#if defined(GGML_WEBGPU_DEBUG) || defined(GGML_WEBGPU_CPU_PROFILE) || defined(GGML_WEBGPU_GPU_PROFILE)
# include <iostream>
#endif
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1))
@ -70,22 +75,21 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim
#endif // GGML_WEBGPU_CPU_PROFILE
#ifdef GGML_WEBGPU_GPU_PROFILE
# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 24
# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 32
# define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16 // e.g. enough for two timestamps
#endif
/* Constants */
#define WEBGPU_NUM_PARAM_BUFS 48u
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16u
#define WEBGPU_NUM_PARAM_BUFS 96u
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u
#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
// Maximum number of in-flight submissions per-thread, to avoid exhausting the
// parameter buffer pool
#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE
#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE)
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 16
#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
// For operations which process a row in parallel, this seems like a reasonable
// default
@ -118,14 +122,9 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
wgpu::BufferUsage usage,
const char * label);
struct webgpu_pool_bufs {
wgpu::Buffer host_buf;
wgpu::Buffer dev_buf;
};
// Holds a pool of parameter buffers for WebGPU operations
struct webgpu_buf_pool {
std::vector<webgpu_pool_bufs> free;
std::vector<wgpu::Buffer> free;
// The pool must be synchronized because
// 1. The memset pool is shared globally by every ggml buffer,
@ -138,7 +137,6 @@ struct webgpu_buf_pool {
size_t cur_pool_size;
size_t max_pool_size;
wgpu::Device device;
wgpu::BufferUsage host_buf_usage;
wgpu::BufferUsage dev_buf_usage;
size_t buf_size;
bool should_grow;
@ -147,53 +145,47 @@ struct webgpu_buf_pool {
int num_bufs,
size_t buf_size,
wgpu::BufferUsage dev_buf_usage,
wgpu::BufferUsage host_buf_usage,
bool should_grow = false,
size_t max_pool_size = WEBGPU_NUM_PARAM_BUFS * 2) {
this->max_pool_size = max_pool_size;
this->cur_pool_size = num_bufs;
this->device = device;
this->host_buf_usage = host_buf_usage;
this->dev_buf_usage = dev_buf_usage;
this->buf_size = buf_size;
this->should_grow = should_grow;
this->max_pool_size = max_pool_size;
this->cur_pool_size = num_bufs;
this->device = device;
this->dev_buf_usage = dev_buf_usage;
this->buf_size = buf_size;
this->should_grow = should_grow;
for (int i = 0; i < num_bufs; i++) {
wgpu::Buffer host_buf;
wgpu::Buffer dev_buf;
ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
free.push_back({ host_buf, dev_buf });
free.push_back(dev_buf);
}
}
webgpu_pool_bufs alloc_bufs() {
wgpu::Buffer alloc_bufs() {
std::unique_lock<std::mutex> lock(mutex);
if (!free.empty()) {
webgpu_pool_bufs bufs = free.back();
wgpu::Buffer buf = free.back();
free.pop_back();
return bufs;
return buf;
}
// Try growing the pool if no free buffers
if (free.empty() && cur_pool_size < max_pool_size && should_grow) {
cur_pool_size++;
wgpu::Buffer host_buf;
wgpu::Buffer dev_buf;
ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
if (!(host_buf && dev_buf)) {
if (!dev_buf) {
GGML_ABORT("webgpu_buf_pool: failed to allocate buffers");
}
return webgpu_pool_bufs{ host_buf, dev_buf };
return dev_buf;
}
cv.wait(lock, [this] { return !free.empty(); });
webgpu_pool_bufs bufs = free.back();
wgpu::Buffer buf = free.back();
free.pop_back();
return bufs;
return buf;
}
void free_bufs(std::vector<webgpu_pool_bufs> bufs) {
void free_bufs(std::vector<wgpu::Buffer> bufs) {
std::lock_guard<std::mutex> lock(mutex);
free.insert(free.end(), bufs.begin(), bufs.end());
cv.notify_all();
@ -201,12 +193,9 @@ struct webgpu_buf_pool {
void cleanup() {
std::lock_guard<std::mutex> lock(mutex);
for (auto & bufs : free) {
if (bufs.host_buf) {
bufs.host_buf.Destroy();
}
if (bufs.dev_buf) {
bufs.dev_buf.Destroy();
for (auto & buf : free) {
if (buf) {
buf.Destroy();
}
}
free.clear();
@ -280,10 +269,9 @@ struct webgpu_gpu_profile_buf_pool {
#endif
struct webgpu_command {
uint32_t num_kernels;
wgpu::CommandBuffer commands;
std::vector<webgpu_pool_bufs> params_bufs;
std::optional<webgpu_pool_bufs> set_rows_error_bufs;
uint32_t num_kernels;
wgpu::CommandBuffer commands;
std::vector<wgpu::Buffer> params_bufs;
#ifdef GGML_WEBGPU_GPU_PROFILE
webgpu_gpu_profile_bufs timestamp_query_bufs;
std::string pipeline_name;
@ -358,6 +346,13 @@ struct webgpu_global_context_struct {
typedef std::shared_ptr<webgpu_global_context_struct> webgpu_global_context;
struct webgpu_submission {
wgpu::FutureWaitInfo submit_done;
#ifdef GGML_WEBGPU_GPU_PROFILE
std::vector<wgpu::FutureWaitInfo> profile_futures;
#endif
};
// All the base objects needed to run operations on a WebGPU device
struct webgpu_context_struct {
// Points to global instances owned by ggml_backend_webgpu_reg_context
@ -366,7 +361,8 @@ struct webgpu_context_struct {
std::unique_ptr<ggml_webgpu_shader_lib> shader_lib;
webgpu_buf_pool param_buf_pool;
webgpu_buf_pool set_rows_error_buf_pool;
wgpu::Buffer set_rows_dev_error_buf;
wgpu::Buffer set_rows_host_error_buf;
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
@ -458,67 +454,105 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
/** End WebGPU object initializations */
/** WebGPU Actions */
static void erase_completed(std::vector<wgpu::FutureWaitInfo> & futures) {
static bool ggml_backend_webgpu_handle_wait_status(wgpu::WaitStatus status, bool allow_timeout = false) {
switch (status) {
case wgpu::WaitStatus::Success:
return true;
case wgpu::WaitStatus::TimedOut:
if (allow_timeout) {
return false;
}
GGML_LOG_ERROR("ggml_webgpu: WaitAny timed out unexpectedly\n");
return false;
case wgpu::WaitStatus::Error:
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
return false;
default:
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
return false;
}
}
#ifdef GGML_WEBGPU_GPU_PROFILE
static void ggml_backend_webgpu_erase_completed_futures(std::vector<wgpu::FutureWaitInfo> & futures) {
futures.erase(std::remove_if(futures.begin(), futures.end(),
[](const wgpu::FutureWaitInfo & info) { return info.completed; }),
futures.end());
}
// Wait for the queue to finish processing all submitted work
static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
std::vector<wgpu::FutureWaitInfo> & futures,
bool block = true) {
// If we have too many in-flight submissions, wait on the oldest one first.
static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context & ctx,
std::vector<wgpu::FutureWaitInfo> & futures,
bool block) {
if (futures.empty()) {
return;
}
uint64_t timeout_ms = block ? UINT64_MAX : 0;
while (futures.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {
auto waitStatus = ctx->instance.WaitAny(1, &futures[0], UINT64_MAX);
if (waitStatus == wgpu::WaitStatus::Error) {
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
if (block) {
while (!futures.empty()) {
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
if (ggml_backend_webgpu_handle_wait_status(waitStatus)) {
ggml_backend_webgpu_erase_completed_futures(futures);
}
}
if (futures[0].completed) {
futures.erase(futures.begin());
} else {
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) {
ggml_backend_webgpu_erase_completed_futures(futures);
}
}
}
#endif
// Wait for the queue to finish processing all submitted work
static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
std::vector<webgpu_submission> & subs,
bool block = true) {
// If we have too many in-flight submissions, wait on the oldest one first.
if (subs.empty()) {
return;
}
while (subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {
auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, UINT64_MAX);
if (ggml_backend_webgpu_handle_wait_status(waitStatus)) {
#ifdef GGML_WEBGPU_GPU_PROFILE
ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true);
#endif
subs.erase(subs.begin());
}
}
if (futures.empty()) {
if (subs.empty()) {
return;
}
if (block) {
while (!futures.empty()) {
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
switch (waitStatus) {
case wgpu::WaitStatus::Success:
// WaitAny doesn't tell us which future completed, so we must check all futures to see which finished.
erase_completed(futures);
break;
case wgpu::WaitStatus::Error:
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
break;
default:
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
break;
for (auto & sub : subs) {
while (!sub.submit_done.completed) {
auto waitStatus = ctx->instance.WaitAny(1, &sub.submit_done, UINT64_MAX);
ggml_backend_webgpu_handle_wait_status(waitStatus);
}
#ifdef GGML_WEBGPU_GPU_PROFILE
ggml_backend_webgpu_wait_profile_futures(ctx, sub.profile_futures, true);
#endif
}
subs.clear();
} else {
// Poll once and return
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
switch (waitStatus) {
case wgpu::WaitStatus::Success:
// WaitAny doesn't tell us which future completed, so we must check all futures to see which finished.
erase_completed(futures);
break;
case wgpu::WaitStatus::TimedOut:
break;
case wgpu::WaitStatus::Error:
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
break;
default:
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
break;
// Poll each submit future once and remove completed submissions.
for (auto sub = subs.begin(); sub != subs.end();) {
auto waitStatus = ctx->instance.WaitAny(1, &sub->submit_done, 0);
ggml_backend_webgpu_handle_wait_status(waitStatus, true);
#ifdef GGML_WEBGPU_GPU_PROFILE
ggml_backend_webgpu_wait_profile_futures(ctx, sub->profile_futures, false);
if (sub->submit_done.completed && sub->profile_futures.empty()) {
#else
if (sub->submit_done.completed) {
#endif
sub = subs.erase(sub);
} else {
++sub;
}
}
}
}
@ -554,14 +588,12 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
}
#endif
static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
webgpu_global_context ctx,
std::vector<webgpu_command> commands,
webgpu_buf_pool & param_buf_pool,
webgpu_buf_pool * set_rows_error_buf_pool = nullptr) {
static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context & ctx,
std::vector<webgpu_command> & commands,
webgpu_buf_pool & param_buf_pool) {
std::vector<wgpu::CommandBuffer> command_buffers;
std::vector<webgpu_pool_bufs> params_bufs;
std::vector<webgpu_pool_bufs> set_rows_error_bufs;
std::vector<wgpu::Buffer> params_bufs;
webgpu_submission submission;
#ifdef GGML_WEBGPU_GPU_PROFILE
std::vector<std::pair<std::string, webgpu_gpu_profile_bufs>> pipeline_name_and_ts_bufs;
#endif
@ -569,14 +601,9 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
for (const auto & command : commands) {
command_buffers.push_back(command.commands);
params_bufs.insert(params_bufs.end(), command.params_bufs.begin(), command.params_bufs.end());
if (command.set_rows_error_bufs) {
set_rows_error_bufs.push_back(command.set_rows_error_bufs.value());
}
}
ctx->queue.Submit(command_buffers.size(), command_buffers.data());
std::vector<wgpu::FutureWaitInfo> futures;
wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
wgpu::CallbackMode::AllowSpontaneous,
[&param_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
@ -586,27 +613,7 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
// Free the staged buffers
param_buf_pool.free_bufs(params_bufs);
});
futures.push_back({ p_f });
for (const auto & bufs : set_rows_error_bufs) {
wgpu::Future f = bufs.host_buf.MapAsync(
wgpu::MapMode::Read, 0, bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
[set_rows_error_buf_pool, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
if (status != wgpu::MapAsyncStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
} else {
const uint32_t * error_data = (const uint32_t *) bufs.host_buf.GetConstMappedRange();
if (*error_data) {
GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
}
// We can't unmap in here due to WebGPU reentrancy limitations.
if (set_rows_error_buf_pool) {
set_rows_error_buf_pool->free_bufs({ bufs });
}
}
});
futures.push_back({ f });
}
submission.submit_done = { p_f };
#ifdef GGML_WEBGPU_GPU_PROFILE
for (const auto & command : commands) {
@ -623,14 +630,14 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
// WebGPU timestamps are in ns; convert to ms
double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6;
ctx->shader_gpu_time_ms[label] += elapsed_ms;
// We can't unmap in here due to WebGPU reentrancy limitations.
ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
}
// We can't unmap in here due to WebGPU reentrancy limitations.
ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
});
futures.push_back({ f });
submission.profile_futures.push_back({ f });
}
#endif
return futures;
return submission;
}
static webgpu_command ggml_backend_webgpu_build_multi(
@ -639,32 +646,21 @@ static webgpu_command ggml_backend_webgpu_build_multi(
const std::vector<webgpu_pipeline> & pipelines,
const std::vector<std::vector<uint32_t>> & params_list,
const std::vector<std::vector<wgpu::BindGroupEntry>> & bind_group_entries_list,
const std::vector<std::pair<uint32_t, uint32_t>> & workgroups_list,
const std::optional<webgpu_pool_bufs> & set_rows_error_bufs = std::nullopt) {
const std::vector<std::pair<uint32_t, uint32_t>> & workgroups_list) {
GGML_ASSERT(pipelines.size() == params_list.size());
GGML_ASSERT(pipelines.size() == bind_group_entries_list.size());
GGML_ASSERT(pipelines.size() == workgroups_list.size());
std::vector<webgpu_pool_bufs> params_bufs_list;
std::vector<wgpu::BindGroup> bind_groups;
std::vector<wgpu::Buffer> params_bufs_list;
std::vector<wgpu::BindGroup> bind_groups;
for (size_t i = 0; i < pipelines.size(); i++) {
webgpu_pool_bufs params_bufs = param_buf_pool.alloc_bufs();
ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0,
params_bufs.host_buf.GetSize());
uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
for (size_t j = 0; j < params_list[i].size(); j++) {
_params[j] = params_list[i][j];
}
params_bufs.host_buf.Unmap();
wgpu::Buffer params_bufs = param_buf_pool.alloc_bufs();
std::vector<wgpu::BindGroupEntry> entries = bind_group_entries_list[i];
uint32_t params_binding_num = entries.size();
entries.push_back({ .binding = params_binding_num,
.buffer = params_bufs.dev_buf,
.offset = 0,
.size = params_bufs.dev_buf.GetSize() });
entries.push_back(
{ .binding = params_binding_num, .buffer = params_bufs, .offset = 0, .size = params_bufs.GetSize() });
wgpu::BindGroupDescriptor bind_group_desc;
bind_group_desc.layout = pipelines[i].pipeline.GetBindGroupLayout(0);
@ -677,15 +673,8 @@ static webgpu_command ggml_backend_webgpu_build_multi(
}
wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
for (const auto & params_bufs : params_bufs_list) {
encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());
}
// If there are SET_ROWS operations in this submission, copy their error
// buffers to the host.
if (set_rows_error_bufs) {
encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0,
set_rows_error_bufs->host_buf.GetSize());
for (size_t i = 0; i < params_bufs_list.size(); i++) {
ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t));
}
#ifdef GGML_WEBGPU_GPU_PROFILE
@ -718,7 +707,6 @@ static webgpu_command ggml_backend_webgpu_build_multi(
webgpu_command result = {};
result.commands = commands;
result.params_bufs = params_bufs_list;
result.set_rows_error_bufs = set_rows_error_bufs;
result.num_kernels = pipelines.size();
#ifdef GGML_WEBGPU_GPU_PROFILE
result.timestamp_query_bufs = ts_bufs;
@ -734,13 +722,13 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_global_context &
std::vector<uint32_t> params,
std::vector<wgpu::BindGroupEntry> bind_group_entries,
uint32_t wg_x,
uint32_t wg_y = 1,
std::optional<webgpu_pool_bufs> set_rows_error_bufs = std::nullopt) {
uint32_t wg_y = 1) {
return ggml_backend_webgpu_build_multi(ctx, param_buf_pool,
{
pipeline
},
{ params }, { bind_group_entries }, { { wg_x, wg_y } }, set_rows_error_bufs);
{ std::move(params) }, { std::move(bind_group_entries) },
{ { wg_x, wg_y } });
}
static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
@ -757,8 +745,9 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
webgpu_command command =
ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);
auto futures = ggml_backend_webgpu_submit(ctx, { command }, ctx->memset_buf_pool);
ggml_backend_webgpu_wait(ctx, futures);
std::vector<webgpu_command> commands = { command };
std::vector<webgpu_submission> sub = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) };
ggml_backend_webgpu_wait(ctx, sub);
}
/** End WebGPU Actions */
@ -805,7 +794,8 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) {
std::cout << "\nggml_webgpu: gpu breakdown:\n";
for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) {
double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0;
std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << std::fixed << std::setprecision(2)
<< pct << "%)\n";
}
#endif
@ -978,14 +968,6 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
auto * decisions = static_cast<ggml_webgpu_set_rows_shader_decisions *>(pipeline.context.get());
std::optional<webgpu_pool_bufs> error_bufs = std::nullopt;
if (decisions->i64_idx) {
error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
if (error_bufs->host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
error_bufs->host_buf.Unmap();
}
}
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
@ -1018,8 +1000,10 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
};
if (decisions->i64_idx) {
entries.push_back(
{ .binding = 3, .buffer = error_bufs->dev_buf, .offset = 0, .size = error_bufs->dev_buf.GetSize() });
entries.push_back({ .binding = 3,
.buffer = ctx->set_rows_dev_error_buf,
.offset = 0,
.size = ctx->set_rows_dev_error_buf.GetSize() });
}
uint32_t threads;
@ -1029,8 +1013,7 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
}
uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1,
error_bufs);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1);
}
// Workgroup size is a common constant
@ -1108,12 +1091,26 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
use_fast = (src0->type == GGML_TYPE_F16);
break;
case GGML_TYPE_F32:
// TODO: implement better mat-mat for k-quants, mat-vec for all k-quants except q6_K
switch (src0->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q6_K:
use_fast = true;
break;
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
// we don't have fast mat-vec for these types, but we do have (semi) fast mat-mat
use_fast = !is_vec;
break;
default:
break;
}
@ -1187,17 +1184,18 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
if (use_fast && is_vec) {
auto decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
uint32_t batches = dst->ne[2] * dst->ne[3];
uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
uint32_t total_wg = output_groups * batches;
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
} else if (use_fast) {
auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
auto * decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
// Fast-path tiled/subgroup calculations
uint32_t wg_m, wg_n;
uint32_t wg_m;
uint32_t wg_n;
if (decisions->use_subgroup_matrix) {
uint32_t wg_m_sg_tile =
decisions->subgroup_m * decisions->subgroup_matrix_m * ctx->global_ctx->capabilities.sg_mat_m;
@ -1215,7 +1213,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
} else { // legacy
auto decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t wg_size = decisions->wg_size;
uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
@ -1514,10 +1512,10 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
}
static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst) {
uint32_t ne = (uint32_t) ggml_nelements(dst);
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst) {
uint32_t ne = (uint32_t) ggml_nelements(dst);
uint32_t dim = (uint32_t) dst->op_params[0];
std::vector<uint32_t> params = {
@ -1538,28 +1536,22 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
(uint32_t) dst->ne[2],
(uint32_t) dst->ne[3],
dim,
(uint32_t)src0->ne[dim]
(uint32_t) src0->ne[dim]
};
std::vector<wgpu::BindGroupEntry> entries = {
{
.binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0)
},
{
.binding = 1,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1)
},
{
.binding = 2,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst)
}
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
{ .binding = 2,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};
ggml_webgpu_shader_lib_context shader_lib_ctx = {
@ -1569,9 +1561,9 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
@ -1623,7 +1615,12 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
const int mode = ((int32_t *) dst->op_params)[2];
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
float freq_base;
float freq_scale;
float ext_factor;
float attn_factor;
float beta_fast;
float beta_slow;
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
@ -2172,19 +2169,12 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
case GGML_OP_SOFT_MAX:
return ggml_webgpu_soft_max(ctx, src0, src1, src2, node);
case GGML_OP_UNARY:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_CLAMP:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_FILL:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_LOG:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_SQR:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_SQRT:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_SIN:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_COS:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_PAD:
@ -2192,7 +2182,6 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
case GGML_OP_ARGMAX:
return ggml_webgpu_argmax(ctx, src0, node);
case GGML_OP_ARGSORT:
return ggml_webgpu_argsort(ctx, src0, node);
case GGML_OP_TOP_K:
// we reuse the same argsort implementation for top_k
return ggml_webgpu_argsort(ctx, src0, node);
@ -2214,33 +2203,51 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);
std::vector<webgpu_command> commands;
std::vector<wgpu::FutureWaitInfo> futures;
uint32_t num_batched_kernels = 0;
std::vector<webgpu_command> commands;
std::vector<webgpu_submission> subs;
uint32_t num_batched_kernels = 0;
bool contains_set_rows = false;
for (int i = 0; i < cgraph->n_nodes; i++) {
if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) {
contains_set_rows = true;
}
if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
commands.push_back(*cmd);
num_batched_kernels += cmd.value().num_kernels;
}
if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
num_batched_kernels = 0;
std::vector<wgpu::FutureWaitInfo> compute_futures = ggml_backend_webgpu_submit(
ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
futures.insert(futures.end(), compute_futures.begin(), compute_futures.end());
num_batched_kernels = 0;
subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool));
// Process events and check for completed submissions
ctx->global_ctx->instance.ProcessEvents();
ggml_backend_webgpu_wait(ctx->global_ctx, futures, false);
ggml_backend_webgpu_wait(ctx->global_ctx, subs, false);
commands.clear();
}
}
if (!commands.empty()) {
auto new_futures =
ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
futures.insert(futures.end(), new_futures.begin(), new_futures.end());
subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool));
commands.clear();
}
ggml_backend_webgpu_wait(ctx->global_ctx, futures);
// If there are SET_ROWS operations in this graph, copy the error buffers to the host for checking.
if (contains_set_rows) {
wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder();
encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0,
ctx->set_rows_host_error_buf.GetSize());
wgpu::CommandBuffer set_rows_commands = encoder.Finish();
ctx->global_ctx->queue.Submit(1, &set_rows_commands);
ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->set_rows_host_error_buf, wgpu::MapMode::Read, 0,
ctx->set_rows_host_error_buf.GetSize());
const uint32_t * error_data = (const uint32_t *) ctx->set_rows_host_error_buf.GetConstMappedRange();
if (*error_data) {
GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
}
ctx->set_rows_host_error_buf.Unmap();
}
ggml_backend_webgpu_wait(ctx->global_ctx, subs);
WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx);
return GGML_STATUS_SUCCESS;
}
@ -2859,10 +2866,12 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true);
webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_dev_error_buf,
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "set_rows_dev_error_buf");
ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_host_error_buf,
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf");
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);

View File

@ -11,7 +11,7 @@ fn store_shmem(val: vec4<f16>, idx: u32) {
shmem[idx + 2] = val.z;
shmem[idx + 3] = val.w;
}
#endif
#endif // VEC
#ifdef SCALAR
#define VEC_SIZE 1
@ -23,7 +23,7 @@ fn store_shmem(val: vec4<f16>, idx: u32) {
fn store_shmem(val: f16, idx: u32) {
shmem[idx] = val;
}
#endif
#endif // SCALAR
#ifdef INIT_SRC0_SHMEM_FLOAT
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
@ -40,7 +40,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
store_shmem(SHMEM_TYPE(src0_val), elem_idx);
}
}
#endif
#endif // INIT_SRC0_SHMEM_FLOAT
#ifdef INIT_SRC1_SHMEM_FLOAT
fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
@ -57,7 +57,7 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3
store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx);
}
}
#endif
#endif // INIT_SRC1_SHMEM_FLOAT
#ifdef INIT_SRC0_SHMEM_Q4_0
const BLOCK_SIZE = 32u;
@ -100,4 +100,667 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
}
}
}
#endif
#endif // INIT_SRC0_SHMEM_Q4_0
#ifdef INIT_SRC0_SHMEM_Q4_1
const BLOCK_SIZE = 32u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 10u; // 1 scale + 8 packed weights + 1 mean
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
let m = src0[scale_idx + 1u];
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = src0[scale_idx + 2u + block_offset + j];
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_lo = f16(q_byte & 0xF) * d + m;
let q_hi = f16((q_byte >> 4) & 0xF) * d + m;
shmem[shmem_idx + j * 2 + k] = q_lo;
shmem[shmem_idx + j * 2 + k + 16u] = q_hi;
}
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q4_1
#ifdef INIT_SRC0_SHMEM_Q5_0
// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
const BLOCK_SIZE = 32u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
// tile_k is defined as 32u, so blocks_k ends up being 1 always
override BLOCKS_K = TILE_K / BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 11u; // 1 scale + 2 qh + 8 packed weights
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
let qh0 = src0[scale_idx + 1u];
let qh1 = src0[scale_idx + 2u];
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
for (var j = 0u; j < 2; j++) {
let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let j_adjusted = j + (block_offset / 2u);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
let q_lo = (f16((q_byte & 0xF) | qh_lo) - 16.0) * d;
shmem[shmem_idx + j * 4u + k] = q_lo; // store first weight
shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight
}
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q5_0
#ifdef INIT_SRC0_SHMEM_Q5_1
// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
const BLOCK_SIZE = 32u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
// tile_k is defined as 32u, so blocks_k ends up being 1 always
override BLOCKS_K = TILE_K / BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 12u; // 1 scale + 2 qh + 8 packed weights + 1 mean
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
let m = src0[scale_idx + 1u];
let qh0 = src0[scale_idx + 2u];
let qh1 = src0[scale_idx + 3u];
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
for (var j = 0u; j < 2; j++) {
let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let j_adjusted = j + (block_offset / 2u);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi)) * d + m;
let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
let q_lo = (f16((q_byte & 0xF) | qh_lo)) * d + m;
shmem[shmem_idx + j * 4u + k] = q_lo; // store first weight
shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight
}
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q5_1
#ifdef INIT_SRC0_SHMEM_Q8_0
const BLOCK_SIZE = 32u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 17u; // 1 scale + 16 in array of weights
const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
let q_0 = src0[scale_idx + 1u + block_offset + j];
let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f16(q_byte) * d;
shmem[shmem_idx + j * 2 + k] = q_val;
}
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q8_0
#ifdef INIT_SRC0_SHMEM_Q8_1
const BLOCK_SIZE = 32u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 18u; // 1 scale + 1 mean + 8 32-bit values in array of weights
const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread, 2 threads per block
// Dequantize a tile of Q8_1 src0 into shared memory.
// Same layout as Q8_0 plus a per-block offset: slot 0 = scale d,
// slot 1 = offset m, slots 2.. = 32 signed 8-bit weights (two per f16 slot).
// Dequantized value = q * d + m. Out-of-bounds positions are skipped.
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE; // quant-block index within the shmem tile
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; // f16 slot inside the block
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; // destination element in shmem
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m; // row in src0
let block_k = blck_idx % BLOCKS_K;
let global_k = k_outer / BLOCK_SIZE + block_k; // block index along k (in block units)
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx]; // per-block scale
let m = src0[scale_idx + 1u]; // per-block offset
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
// load two adjacent f16 slots and repack as one u32 (4 quantized bytes)
let q_0 = src0[scale_idx + 2u + block_offset + j];
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1))
;
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k); // signed 8-bit weight
let q_val = f16(q_byte) * d + m;
shmem[shmem_idx + j * 2 + k] = q_val;
}
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q8_1
#ifdef INIT_SRC0_SHMEM_Q2_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 42u;
// Dequantize a tile of Q2_K src0 into shared memory, one element per thread
// per loop step. Block layout in f16 slots (byte offsets in parens):
// slots 0..7 = 16 scale bytes (0..15), slots 8..39 = 64 bytes of 2-bit
// weights (16..79), slot 40 = f16 super-scale d, slot 41 = f16 super-min dmin.
// Each scale byte covers 16 weights: low nibble scales d, high nibble scales
// dmin. Dequantized value = qs * (d * scale) - (dmin * min).
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
// Use standard thread layout instead of lane/row_group
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0); // zero-pad out-of-bounds elements
continue;
}
let block_k = global_k / BLOCK_SIZE; // 256-element super-block index
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK; // f16 index of the block start
let d = src0[scale_idx + 40u]; // super-block scale
let dmin = src0[scale_idx + 41u]; // super-block min
// Decode the element at position k_in_block
let block_of_32 = k_in_block / 32u;
let pos_in_32 = k_in_block % 32u;
let q_b_idx = (block_of_32 / 4u) * 32u; // byte offset of the 32-byte qs half
let shift = (block_of_32 % 4u) * 2u; // 2-bit field position within a qs byte
let k = (pos_in_32 / 16u) * 16u;
let l = pos_in_32 % 16u;
let is = k_in_block / 16u; // scale byte index (0..15), one per 16 weights
// fetch the scale byte (scales are packed 4 bytes per u32 load)
let sc_0 = src0[scale_idx + 2u * (is / 4u)];
let sc_1 = src0[scale_idx + 2u * (is / 4u) + 1u];
let sc_packed = bitcast<u32>(vec2(sc_0, sc_1));
let sc = get_byte(sc_packed, is % 4u);
let dl = d * f16(sc & 0xFu); // effective scale for this group
let ml = dmin * f16(sc >> 4u); // effective min for this group
let q_idx = q_b_idx + k + l; // byte index into the qs array
let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qs_val = (q_byte >> shift) & 3u; // extract the 2-bit weight
let q_val = f16(qs_val) * dl - ml;
shmem[elem_idx] = q_val;
}
}
#endif // INIT_SRC0_SHMEM_Q2_K
#ifdef INIT_SRC0_SHMEM_Q3_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 55u;
// Dequantize a tile of Q3_K src0 into shared memory, one element per thread
// per loop step. Block layout in f16 slots (byte offsets in parens):
// slots 0..15 = 32 bytes hmask (high bits), slots 16..47 = 64 bytes of 2-bit
// low qs, slots 48..53 = 12 bytes of packed 6-bit scales, slot 54 = f16 d.
// A weight is the 2-bit qs value, with 4 subtracted when its hmask bit is
// clear; the result is scaled by d * (scale - 32).
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0); // zero-pad out-of-bounds elements
continue;
}
let block_k = global_k / BLOCK_SIZE; // 256-element super-block index
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK; // f16 index of the block start
let d = src0[scale_idx + 54u]; // super-block scale
// Load and unpack scales: 16 6-bit scales are packed into 12 bytes;
// the kmask constants separate the low-4 and high-2 bit groups.
let kmask1: u32 = 0x03030303u;
let kmask2: u32 = 0x0f0f0f0fu;
var scale_vals: array<u32, 4>;
for (var i: u32 = 0u; i < 4u; i++) {
let scale_0 = src0[scale_idx + 48u + (2u*i)];
let scale_1 = src0[scale_idx + 48u + (2u*i) + 1u];
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
}
// rebuild the 16 scale bytes: low 4 bits from words 0-1, high 2 bits from word 2
var tmp: u32 = scale_vals[2];
scale_vals[2] = ((scale_vals[0] >> 4u) & kmask2) | (((tmp >> 4u) & kmask1) << 4u);
scale_vals[3] = ((scale_vals[1] >> 4u) & kmask2) | (((tmp >> 6u) & kmask1) << 4u);
scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4u);
scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2u) & kmask1) << 4u);
// Load hmask and qs arrays
var hmask_vals: array<u32, 8>;
for (var i: u32 = 0u; i < 8u; i++) {
let hmask_0 = src0[scale_idx + (2u*i)];
let hmask_1 = src0[scale_idx + (2u*i) + 1u];
hmask_vals[i] = bitcast<u32>(vec2(hmask_0, hmask_1));
}
var qs_vals: array<u32, 16>;
for (var i: u32 = 0u; i < 16u; i++) {
let qs_0 = src0[scale_idx + 16u + (2u*i)];
let qs_1 = src0[scale_idx + 16u + (2u*i) + 1u];
qs_vals[i] = bitcast<u32>(vec2(qs_0, qs_1));
}
// Decompose k_in_block into the indices the reference decode loop would use.
let half = k_in_block / 128u; // 0 or 1
let pos_in_half = k_in_block % 128u; // 0-127
let shift_group = pos_in_half / 32u; // 0-3
let pos_in_32 = pos_in_half % 32u; // 0-31
let k_group = pos_in_32 / 16u; // 0 or 1
let l = pos_in_32 % 16u; // 0-15
let q_b_idx = half * 32u; // 0 or 32
let shift = shift_group * 2u; // 0, 2, 4, 6
let k = k_group * 16u; // 0 or 16
let is = k_in_block / 16u; // scale byte index, 0-15
// m increments every 32 elements across entire 256 element block
let m_shift = k_in_block / 32u; // 0-7
let m: u32 = 1u << m_shift; // 1,2,4,8,16,32,64,128 (hmask bit for this group)
let sc = get_byte(scale_vals[is / 4u], is % 4u);
let dl = d * (f16(sc) - 32.0); // scales are stored with a +32 bias
let q_idx = q_b_idx + k + l; // byte index into qs
let hm_idx = k + l; // byte index into hmask
let q_byte = get_byte(qs_vals[q_idx / 4u], q_idx % 4u);
let hmask_byte = get_byte(hmask_vals[hm_idx / 4u], hm_idx % 4u);
// high bit set -> subtract nothing; clear -> subtract 4
let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
let qs_val = (q_byte >> shift) & 3u; // 2-bit low part of the weight
let q_val = (f16(qs_val) - f16(hm)) * dl;
shmem[elem_idx] = q_val;
}
}
#endif // INIT_SRC0_SHMEM_Q3_K
#ifdef INIT_SRC0_SHMEM_Q4_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 72u;
// Dequantize a tile of Q4_K src0 into shared memory, one element per thread
// per loop step. Block layout in f16 slots: slot 0 = f16 d, slot 1 = f16
// dmin, slots 2..7 = 12 bytes of packed 6-bit scales/mins (8 of each),
// slots 8..71 = 128 bytes of 4-bit weights. Each scale/min pair covers a
// group of 32 weights. Dequantized value = q * (d * sc) - (dmin * mn).
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0); // zero-pad out-of-bounds elements
continue;
}
let block_k = global_k / BLOCK_SIZE; // 256-element super-block index
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK; // f16 index of the block start
let d = src0[scale_idx]; // super-block scale
let dmin = src0[scale_idx + 1u]; // super-block min
// Load packed scales
var scale_vals: array<u32, 3>;
for (var i: u32 = 0u; i < 3u; i++) {
let scale_0 = src0[scale_idx + 2u + (2u*i)];
let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
}
// Map k_in_block to loop structure:
// Outer loop over 64-element groups (alternating q_b_idx)
// Inner loop over 2 shifts per group
let group_of_64 = k_in_block / 64u; // 0-3 (maps to q_b_idx)
let pos_in_64 = k_in_block % 64u; // 0-63
let shift_group = pos_in_64 / 32u; // 0 or 1
let l = pos_in_64 % 32u; // 0-31
let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96
let shift = shift_group * 4u; // 0 or 4 (low or high nibble)
let is = k_in_block / 32u; // scale/min pair index, 0-7
// Unpack the 6-bit scale (sc) and min (mn) for group `is`.
// Pairs 0-3 live in whole bytes; pairs 4-7 are split: low 4 bits in a
// shared byte, high 2 bits reuse the top bits of the first 8 bytes.
var sc: u32;
var mn: u32;
if (is < 4u) {
let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
sc = sc_byte & 63u;
mn = min_byte & 63u;
} else {
let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
}
let dl = d * f16(sc); // effective scale for this group
let ml = dmin * f16(mn); // effective min for this group
let q_idx = q_b_idx + l; // byte index into the qs array
let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qs_val = (q_byte >> shift) & 0xFu; // extract the 4-bit weight
let q_val = f16(qs_val) * dl - ml;
shmem[elem_idx] = q_val;
}
}
#endif // INIT_SRC0_SHMEM_Q4_K
#ifdef INIT_SRC0_SHMEM_Q5_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 88u;
// Dequantize a tile of Q5_K src0 into shared memory, one element per thread
// per loop step. Block layout in f16 slots: slot 0 = f16 d, slot 1 = f16
// dmin, slots 2..7 = 12 bytes of packed 6-bit scales/mins (same packing as
// Q4_K), slots 8..23 = 32 bytes of high bits (qh), slots 24..87 = 128 bytes
// of 4-bit low weights (qs). A weight is the 4-bit qs value plus 16 when its
// qh bit is set. Dequantized value = (qs + qh*16) * (d * sc) - (dmin * mn).
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0); // zero-pad out-of-bounds elements
continue;
}
let block_k = global_k / BLOCK_SIZE; // 256-element super-block index
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK; // f16 index of the block start
let d = src0[scale_idx]; // super-block scale
let dmin = src0[scale_idx + 1u]; // super-block min
// Load packed scales
var scale_vals: array<u32, 3>;
for (var i: u32 = 0u; i < 3u; i++) {
let scale_0 = src0[scale_idx + 2u + (2u*i)];
let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
}
// The original loop processes elements in groups of 64
// Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4]
// But u increments EVERY 32 elements (after each l loop)
let group_of_64 = k_in_block / 64u; // 0-3
let pos_in_64 = k_in_block % 64u; // 0-63
let shift_group = pos_in_64 / 32u; // 0 or 1
let l = pos_in_64 % 32u; // 0-31
let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96
let shift = shift_group * 4u; // 0 or 4 (low or high nibble)
let is = k_in_block / 32u; // scale/min pair index, 0-7
// u increments every 32 elements (0->1, 1->2, 2->4, 3->8, 4->16, 5->32, 6->64, 7->128)
let u_shift = k_in_block / 32u; // 0-7
let u: u32 = 1u << u_shift; // qh bit mask for this group of 32
// Unpack the 6-bit scale (sc) and min (mn) for group `is` — same split
// packing as Q4_K: pairs 0-3 in whole bytes, pairs 4-7 split across bytes.
var sc: u32;
var mn: u32;
if (is < 4u) {
let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
sc = sc_byte & 63u;
mn = min_byte & 63u;
} else {
let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
}
let dl = d * f16(sc); // effective scale for this group
let ml = dmin * f16(mn); // effective min for this group
let q_idx = q_b_idx + l; // byte index into qs
let q_0 = src0[scale_idx + 24u + 2u * (q_idx / 4u)];
let q_1 = src0[scale_idx + 24u + 2u * (q_idx / 4u) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte = get_byte(q_packed, q_idx % 4u);
// qh byte is indexed by l alone; the group selects which bit via `u`
let qh_0 = src0[scale_idx + 8u + 2u * (l / 4u)];
let qh_1 = src0[scale_idx + 8u + 2u * (l / 4u) + 1u];
let qh_packed = bitcast<u32>(vec2(qh_0, qh_1));
let qh_byte = get_byte(qh_packed, l % 4u);
let qs_val = (q_byte >> shift) & 0xFu; // 4-bit low part
let qh_val = select(0.0, 16.0, (qh_byte & u) != 0); // +16 if the high bit is set
let q_val = (f16(qs_val) + f16(qh_val)) * dl - ml;
shmem[elem_idx] = q_val;
}
}
#endif // INIT_SRC0_SHMEM_Q5_K
#ifdef INIT_SRC0_SHMEM_Q6_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 105u;
// Dequantize a tile of Q6_K src0 into shared memory, one element per thread
// per loop step. Block layout in f16 slots: slots 0..63 = 128 bytes of 4-bit
// low weights (ql), slots 64..95 = 64 bytes of 2-bit high weights (qh),
// slots 96..103 = 16 signed 8-bit scales, slot 104 = f16 d.
// A weight is a 6-bit value (4 low bits from ql, 2 high bits from qh) minus
// 32; output = d * scale * weight.
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0); // zero-pad out-of-bounds elements
continue;
}
let block_k = global_k / BLOCK_SIZE; // 256-element super-block index
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK; // f16 index of the block start
// Decompose the element position: each 128-element half uses its own
// 64-byte ql region and 32-byte qh region; the quarter (group of 32)
// selects which of the 4 decoded values (q1..q4) applies.
let half = k_in_block / 128u;
let pos_in_half = k_in_block % 128u;
let quarter = pos_in_half / 32u;
let l = pos_in_half % 32u;
let ql_b_idx = half * 64u; // byte offset into ql for this half
let qh_b_idx = half * 32u; // byte offset into qh for this half
let sc_b_idx = half * 8u; // byte offset into the scales for this half
// Load only ql13 word needed (low nibbles for q1, high nibbles for q3)
let ql13_flat = ql_b_idx + l;
let ql13_word = ql13_flat / 4u;
let ql13 = bitcast<u32>(vec2(
src0[scale_idx + 2u * ql13_word],
src0[scale_idx + 2u * ql13_word + 1u]
));
let ql13_b = get_byte(ql13, ql13_flat % 4u);
// Load only ql24 word needed (low nibbles for q2, high nibbles for q4)
let ql24_flat = ql_b_idx + l + 32u;
let ql24_word = ql24_flat / 4u;
let ql24 = bitcast<u32>(vec2(
src0[scale_idx + 2u * ql24_word],
src0[scale_idx + 2u * ql24_word + 1u]
));
let ql24_b = get_byte(ql24, ql24_flat % 4u);
// Load only qh word needed (one byte holds the 2 high bits of q1..q4)
let qh_flat = qh_b_idx + l;
let qh_word = qh_flat / 4u;
let qh = bitcast<u32>(vec2(
src0[scale_idx + 64u + 2u * qh_word],
src0[scale_idx + 64u + 2u * qh_word + 1u]
));
let qh_b = get_byte(qh, qh_flat % 4u);
// Assemble the four candidate 6-bit weights and remove the +32 bias.
let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0);
let q3 = f16((ql13_b >> 4u) | (((qh_b >> 4u) & 3u) << 4u)) - f16(32.0);
let q4 = f16((ql24_b >> 4u) | (((qh_b >> 6u) & 3u) << 4u)) - f16(32.0);
// Load only the scale word needed (one signed scale per 16 weights)
let is = l / 16u;
let sc_idx = sc_b_idx + is + quarter * 2u;
let sc_word = sc_idx / 4u;
let sc = bitcast<u32>(vec2(
src0[scale_idx + 96u + 2u * sc_word],
src0[scale_idx + 96u + 2u * sc_word + 1u]
));
let sc_val = get_byte_i32(sc, sc_idx % 4u); // signed scale byte
let d = src0[scale_idx + 104u]; // super-block scale
// pick the decoded value matching this element's quarter
var q_val: f16;
if (quarter == 0u) {
q_val = q1;
} else if (quarter == 1u) {
q_val = q2;
} else if (quarter == 2u) {
q_val = q3;
} else {
q_val = q4;
}
shmem[elem_idx] = d * f16(sc_val) * q_val;
}
}
#endif // INIT_SRC0_SHMEM_Q6_K

View File

@ -50,6 +50,7 @@ fn get_local_m(thread_id: u32) -> u32 {
const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;
@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)

View File

@ -1,4 +1,3 @@
enable f16;
#include "common_decls.tmpl"
@ -84,6 +83,294 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
}
#endif
#ifdef MUL_ACC_Q4_1
const BLOCK_SIZE = 32;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 10u;
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
// Partial dot product of a Q4_1 src0 row segment with the dequantized
// vector in shared memory. Block layout: slot 0 = scale d, slot 1 = offset
// m, slots 2..9 = 16 bytes of 4-bit weights. Each byte holds two weights:
// the low nibble for element e, the high nibble for element e + 16.
// Dequantized value = nibble * d + m.
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
var local_sum = 0.0;
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE; // quant block within the tile
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; // f16 slot within the block
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]); // per-block scale
let m = f32(src0[scale_idx + 1u]); // per-block offset
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
// load two adjacent f16 slots as one u32 (4 packed bytes = 8 weights)
let q_0 = src0[scale_idx + 2u + block_offset + j];
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = f32((q_byte >> 4) & 0xF) * d + m; // element at +16
let q_lo = f32(q_byte & 0xF) * d + m; // element at base position
local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k];
local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16];
}
}
}
return local_sum;
}
#endif
#ifdef MUL_ACC_Q5_0
const BLOCK_SIZE = 32;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 11u;
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
// Partial dot product of a Q5_0 src0 row segment with the dequantized
// vector in shared memory. Block layout: slot 0 = scale d, slots 1..2 = 32
// packed high bits (qh), slots 3..10 = 16 bytes of 4-bit low weights.
// Each byte holds the low nibble for element e and the high nibble for
// element e + 16; its fifth bit comes from qh.
// Dequantized value = ((nibble | high_bit) - 16) * d.
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
var local_sum = 0.0;
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE; // quant block within the tile
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; // f16 slot within the block
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]); // per-block scale
let qh0 = src0[scale_idx + 1u];
let qh1 = src0[scale_idx + 2u];
let qh_packed = bitcast<u32>(vec2(qh0, qh1)); // all 32 high bits of the block
for (var j = 0u; j < 2; j++) {
let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let j_adjusted = j + (block_offset / 2u); // absolute 4-byte group within the block
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
// select bit 16+e for the high-nibble element, bit e for the low one,
// positioned at value 0x10 (the fifth bit of the weight)
let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;
local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k];
local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16];
}
}
}
return local_sum;
}
#endif
#ifdef MUL_ACC_Q5_1
const BLOCK_SIZE = 32;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 12u;
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
// Partial dot product of a Q5_1 src0 row segment with the dequantized
// vector in shared memory. Block layout: slot 0 = scale d, slot 1 = offset
// m, slots 2..3 = 32 packed high bits (qh), slots 4..11 = 16 bytes of 4-bit
// low weights (low nibble -> element e, high nibble -> element e + 16).
// Dequantized value = (nibble | high_bit) * d + m.
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
var local_sum = 0.0;
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE; // quant block within the tile
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; // f16 slot within the block
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]); // per-block scale
let m = src0[scale_idx + 1u]; // per-block offset
let qh0 = src0[scale_idx + 2u];
let qh1 = src0[scale_idx + 3u];
let qh_packed = bitcast<u32>(vec2(qh0, qh1)); // all 32 high bits of the block
for (var j = 0u; j < 2; j++) {
let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let j_adjusted = j + (block_offset / 2u); // absolute 4-byte group within the block
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
// select bit 16+e (high-nibble element) / bit e (low-nibble element),
// positioned at value 0x10 (the fifth bit of the weight)
let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + f32(m);
let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
let q_lo = f32((q_byte & 0xF) | qh_lo) * d + f32(m);
local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k];
local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16];
}
}
}
return local_sum;
}
#endif
#ifdef MUL_ACC_Q8_0
const BLOCK_SIZE = 32;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 17u;
const WEIGHTS_PER_F16 = 2u;
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
// Partial dot product of a Q8_0 src0 row segment with the dequantized
// vector in shared memory. Block layout: slot 0 = scale d, slots 1..16 =
// 32 signed 8-bit weights (two per f16 slot). Dequantized value = q * d.
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
var local_sum = 0.0;
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE; // quant block within the tile
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; // f16 slot within the block
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]); // per-block scale
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
// load two adjacent f16 slots as one u32 (4 quantized bytes)
let q_0 = src0[scale_idx + 1 + block_offset + j];
let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte_i32(q_packed, k); // signed 8-bit weight
let q_val = f32(q_byte) * d;
local_sum += q_val * shared_vector[shmem_idx + j * 2 + k];
}
}
}
return local_sum;
}
#endif
#ifdef MUL_ACC_Q8_1
const BLOCK_SIZE = 32;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 18u;
const WEIGHTS_PER_F16 = 2u;
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
// Partial dot product of a Q8_1 src0 row segment with the dequantized
// vector in shared memory. Same layout as Q8_0 plus a per-block offset:
// slot 0 = scale d, slot 1 = offset m, slots 2..17 = 32 signed 8-bit
// weights (two per f16 slot). Dequantized value = q * d + m.
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
var local_sum = 0.0;
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE; // quant block within the tile
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; // f16 slot within the block
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]); // per-block scale
let m = src0[scale_idx + 1u]; // per-block offset
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
// load two adjacent f16 slots as one u32 (4 quantized bytes)
let q_0 = src0[scale_idx + 2u + block_offset + j];
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte_i32(q_packed, k); // signed 8-bit weight
let q_val = f32(q_byte) * d + f32(m);
local_sum += q_val * shared_vector[shmem_idx + j * 2 + k];
}
}
}
return local_sum;
}
#endif
#ifdef MUL_ACC_Q6_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 105u;
// Read a 4-byte-aligned u32 out of the f16-typed src0 buffer.
// `bbase` is the f16 index of the block start; `byte_offset` is rounded
// down to a 4-byte boundary, then two consecutive f16s are repacked.
fn load_u32_at(bbase: u32, byte_offset: u32) -> u32 {
    let word_start = byte_offset & ~3u;      // align down to a 4-byte boundary
    let f16_idx = bbase + word_start / 2u;   // each f16 element covers 2 bytes
    return bitcast<u32>(vec2(src0[f16_idx], src0[f16_idx + 1u]));
}
// Byte `b` (0..3) of `v` as an unsigned value.
fn byte_of(v: u32, b: u32) -> u32 {
    return (v >> (8u * b)) & 0xFFu;
}
// Byte `b` (0..3) of `v`, sign-extended to i32.
fn sbyte_of(v: u32, b: u32) -> i32 {
    let unsigned_byte = (v >> (8u * b)) & 0xFFu;
    return select(i32(unsigned_byte), i32(unsigned_byte) - 256, unsigned_byte >= 128u);
}
// Partial dot product of a Q6_K src0 row segment with the dequantized
// vector in shared memory. Block byte layout (f16 slots = bytes/2):
// bytes 0..127 = 4-bit low weights (ql), 128..191 = 2-bit high weights
// (qh), 192..207 = 16 signed 8-bit scales, 208..209 = f16 d.
// Threads pair up: even/odd threads (ix) interleave over blocks; within a
// pair, tid picks a 4-element column (l0) and a block half (ip). Each
// iteration accumulates 4 columns x 4 weight groups, each group scaled by
// its own signed scale byte, then by d.
fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
let tid = tig / 2u;
let ix = tig % 2u; // which of the two interleaved block streams
let ip = tid / 8u; // block half (0 or 1)
let il = tid % 8u;
let l0 = 4u * il; // starting column within the half (0,4,...,28)
let is = 8u * ip + l0 / 16u; // base scale byte index for this thread
let y_offset = 128u * ip + l0; // element offset into the block's 256 outputs
let q_offset_l = 64u * ip + l0; // byte offset into ql for this half
let q_offset_h = 32u * ip + l0; // byte offset into qh for this half
let nb = tile_size / BLOCK_SIZE; // number of 256-element blocks in the tile
let k_block_start = k_outer / BLOCK_SIZE;
// Aligned scale byte position (is can be odd)
let sc_base_byte = 192u + (is & ~3u);
let sc_byte_pos = is & 3u;
var local_sum = 0.0;
for (var i = ix; i < nb; i += 2u) {
let bbase = (idx_base + k_block_start + i) * F16_PER_BLOCK; // f16 index of block start
let d_raw = load_u32_at(bbase, 208u);
let d = f32(bitcast<vec2<f16>>(d_raw)[0]); // super-block scale
let ql1_u32 = load_u32_at(bbase, q_offset_l); // low nibbles for groups 0/2
let ql2_u32 = load_u32_at(bbase, q_offset_l + 32u); // low nibbles for groups 1/3
let qh_u32 = load_u32_at(bbase, 128u + q_offset_h); // 2-bit high parts for all 4 groups
let sc_u32_0 = load_u32_at(bbase, sc_base_byte);
let sc_u32_1 = load_u32_at(bbase, sc_base_byte + 4u);
// one signed scale per group of 16 weights; groups sit 2 bytes apart
let sc0 = sbyte_of(sc_u32_0, sc_byte_pos)
;
let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);
let sc4 = sbyte_of(sc_u32_1, sc_byte_pos);
let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u);
var sums = vec4<f32>(0.0, 0.0, 0.0, 0.0);
for (var l = 0u; l < 4u; l++) {
// the four weight groups for this column sit 32 elements apart
let y_base = i * BLOCK_SIZE + y_offset + l;
let yl0 = f32(shared_vector[y_base]);
let yl1 = f32(shared_vector[y_base + 32u]);
let yl2 = f32(shared_vector[y_base + 64u]);
let yl3 = f32(shared_vector[y_base + 96u]);
let q1b = byte_of(ql1_u32, l);
let q2b = byte_of(ql2_u32, l);
let qhb = byte_of(qh_u32, l);
// assemble each 6-bit weight (4 ql bits + 2 qh bits) and remove the +32 bias
let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32);
let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32);
let dq2 = f32(i32((q1b >> 4u) | ((qhb & 0x30u) )) - 32);
let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32);
sums[0] += yl0 * dq0;
sums[1] += yl1 * dq1;
sums[2] += yl2 * dq2;
sums[3] += yl3 * dq3;
}
// apply per-group scales, then the super-block scale
local_sum += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) +
sums[2] * f32(sc4) + sums[3] * f32(sc6));
}
return local_sum;
}
#endif
struct MulMatParams {
offset_src0: u32,
offset_src1: u32,
@ -191,4 +478,3 @@ fn main(
dst[dst_idx / VEC_SIZE] = store_val(group_base);
}
}

View File

@ -177,6 +177,8 @@ class Keys:
TEMPERATURE_LENGTH = "{arch}.attention.temperature_length"
KEY_LENGTH_MLA = "{arch}.attention.key_length_mla"
VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla"
KEY_LENGTH_SWA = "{arch}.attention.key_length_swa"
VALUE_LENGTH_SWA = "{arch}.attention.value_length_swa"
SHARED_KV_LAYERS = "{arch}.attention.shared_kv_layers"
SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern"
TEMPERATURE_SCALE = "{arch}.attention.temperature_scale"
@ -188,6 +190,7 @@ class Keys:
class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"
DIMENSION_COUNT_SWA = "{arch}.rope.dimension_count_swa"
DIMENSION_SECTIONS = "{arch}.rope.dimension_sections"
FREQ_BASE = "{arch}.rope.freq_base"
FREQ_BASE_SWA = "{arch}.rope.freq_base_swa"

View File

@ -773,6 +773,12 @@ class GGUFWriter:
def add_value_length_mla(self, length: int) -> None:
self.add_uint32(Keys.Attention.VALUE_LENGTH_MLA.format(arch=self.arch), length)
def add_key_length_swa(self, length: int) -> None:
self.add_uint32(Keys.Attention.KEY_LENGTH_SWA.format(arch=self.arch), length)
def add_value_length_swa(self, length: int) -> None:
self.add_uint32(Keys.Attention.VALUE_LENGTH_SWA.format(arch=self.arch), length)
def add_indexer_head_count(self, count: int) -> None:
self.add_uint32(Keys.Attention.Indexer.HEAD_COUNT.format(arch=self.arch), count)
@ -946,6 +952,9 @@ class GGUFWriter:
def add_rope_dimension_count(self, count: int) -> None:
self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)
def add_rope_dimension_count_swa(self, count: int) -> None:
self.add_uint32(Keys.Rope.DIMENSION_COUNT_SWA.format(arch=self.arch), count)
def add_rope_dimension_sections(self, dims: Sequence[int]) -> None:
self.add_array(Keys.Rope.DIMENSION_SECTIONS.format(arch=self.arch), dims)

View File

@ -5,7 +5,7 @@ import os
import sys
import subprocess
HTTPLIB_VERSION = "refs/tags/v0.35.0"
HTTPLIB_VERSION = "refs/tags/v0.37.0"
vendor = {
"https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp",
@ -15,7 +15,7 @@ vendor = {
# not using latest tag to avoid this issue: https://github.com/ggml-org/llama.cpp/pull/17179#discussion_r2515877926
# "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.24/miniaudio.h": "vendor/miniaudio/miniaudio.h",
"https://github.com/mackron/miniaudio/raw/13d161bc8d856ad61ae46b798bbeffc0f49808e8/miniaudio.h": "vendor/miniaudio/miniaudio.h",
"https://github.com/mackron/miniaudio/raw/9634bedb5b5a2ca38c1ee7108a9358a4e233f14d/miniaudio.h": "vendor/miniaudio/miniaudio.h",
f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/httplib.h": "httplib.h",
f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/split.py": "split.py",

View File

@ -230,11 +230,14 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_TEMPERATURE_SCALE, "%s.attention.temperature_scale" },
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
{ LLM_KV_ATTENTION_KEY_LENGTH_SWA, "%s.attention.key_length_swa" },
{ LLM_KV_ATTENTION_VALUE_LENGTH_SWA, "%s.attention.value_length_swa" },
{ LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, "%s.attention.indexer.head_count" },
{ LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" },
{ LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" },
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_DIMENSION_COUNT_SWA, "%s.rope.dimension_count_swa" },
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
{ LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" },
@ -1084,6 +1087,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_OUTPUT,
LLM_TENSOR_CLS_OUT,
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_Q,
LLM_TENSOR_ATTN_Q_NORM,

View File

@ -234,11 +234,14 @@ enum llm_kv {
LLM_KV_ATTENTION_TEMPERATURE_SCALE,
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
LLM_KV_ATTENTION_KEY_LENGTH_SWA,
LLM_KV_ATTENTION_VALUE_LENGTH_SWA,
LLM_KV_ATTENTION_INDEXER_HEAD_COUNT,
LLM_KV_ATTENTION_INDEXER_KEY_LENGTH,
LLM_KV_ATTENTION_INDEXER_TOP_K,
LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_DIMENSION_COUNT_SWA,
LLM_KV_ROPE_DIMENSION_SECTIONS,
LLM_KV_ROPE_FREQ_BASE,
LLM_KV_ROPE_FREQ_BASE_SWA,

View File

@ -2876,19 +2876,23 @@ llama_context * llama_init_from_model(
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) {
const uint32_t blck_size = ggml_blck_size(params.type_k);
if (model->hparams.n_embd_head_k % blck_size != 0) {
LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n",
__func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k);
return nullptr;
for (uint32_t il = 0; il < model->hparams.n_layer; ++il) {
if (model->hparams.n_embd_head_k(il) % blck_size != 0) {
LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n",
__func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k(il));
return nullptr;
}
}
}
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) {
const uint32_t blck_size = ggml_blck_size(params.type_v);
if (model->hparams.n_embd_head_v % blck_size != 0) {
LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_k=%u\n",
__func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v);
return nullptr;
for (uint32_t il = 0; il < model->hparams.n_layer; ++il) {
if (model->hparams.n_embd_head_v(il) % blck_size != 0) {
LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_v=%u\n",
__func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v(il));
return nullptr;
}
}
}

View File

@ -601,7 +601,7 @@ const char * llama_grammar_parser::parse_sequence(
throw std::runtime_error(std::string("expecting an int at ") + pos);
}
const char * int_end = parse_int(pos);
uint64_t min_times = std::stoul(std::string(pos, int_end - pos));
uint64_t min_times = std::stoull(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);
uint64_t max_times = UINT64_MAX; // default: no max limit
@ -614,7 +614,7 @@ const char * llama_grammar_parser::parse_sequence(
if (is_digit_char(*pos)) {
const char * int_end = parse_int(pos);
max_times = std::stoul(std::string(pos, int_end - pos));
max_times = std::stoull(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);
}

View File

@ -250,7 +250,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
const bool last = (
cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
(cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token
(cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL)) // qwen3 reranking & embedding models use last token
);
for (int i = 0; i < n_tokens; ++i) {
@ -849,13 +849,13 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
ubatch (params.ubatch),
n_embd (hparams.n_embd),
n_layer (hparams.n_layer),
n_rot (hparams.n_rot),
n_rot (hparams.n_rot()),
n_ctx (cparams.n_ctx),
n_head (hparams.n_head()),
n_head_kv (hparams.n_head_kv()),
n_embd_head_k (hparams.n_embd_head_k),
n_embd_head_k (hparams.n_embd_head_k()),
n_embd_k_gqa (hparams.n_embd_k_gqa()),
n_embd_head_v (hparams.n_embd_head_v),
n_embd_head_v (hparams.n_embd_head_v()),
n_embd_v_gqa (hparams.n_embd_v_gqa()),
n_expert (hparams.n_expert),
n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
@ -2552,7 +2552,7 @@ void llm_graph_context::build_pooling(
}
// softmax for qwen3 reranker
if (arch == LLM_ARCH_QWEN3) {
if (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL) {
cur = ggml_soft_max(ctx0, cur);
}
} break;

View File

@ -62,6 +62,14 @@ uint32_t llama_hparams::n_gqa(uint32_t il) const {
return n_head/n_head_kv;
}
uint32_t llama_hparams::n_rot(uint32_t il) const {
if (il < n_layer) {
return is_swa(il) ? n_rot_swa : n_rot_full;
}
GGML_ABORT("fatal error");
}
uint32_t llama_hparams::n_embd_inp() const {
uint32_t n_embd_inp = n_embd;
@ -76,16 +84,32 @@ uint32_t llama_hparams::n_embd_out() const {
return n_embd_out_impl > 0 ? n_embd_out_impl : n_embd;
}
uint32_t llama_hparams::n_embd_head_k(uint32_t il) const {
if (il < n_layer) {
return is_swa(il) ? n_embd_head_k_swa : n_embd_head_k_full;
}
GGML_ABORT("fatal error");
}
uint32_t llama_hparams::n_embd_head_v(uint32_t il) const {
if (il < n_layer) {
return is_swa(il) ? n_embd_head_v_swa : n_embd_head_v_full;
}
GGML_ABORT("fatal error");
}
uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
const uint32_t n_head_kv = this->n_head_kv(il);
return n_embd_head_k * n_head_kv;
return n_embd_head_k(il) * n_head_kv;
}
uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
const uint32_t n_head_kv = this->n_head_kv(il);
return n_embd_head_v * n_head_kv;
return n_embd_head_v(il) * n_head_kv;
}
bool llama_hparams::is_n_embd_k_gqa_variable() const {
@ -197,11 +221,11 @@ bool llama_hparams::is_mla() const {
}
uint32_t llama_hparams::n_embd_head_k_mla() const {
return is_mla() ? n_embd_head_k_mla_impl : n_embd_head_k;
return is_mla() ? n_embd_head_k_mla_impl : n_embd_head_k();
}
uint32_t llama_hparams::n_embd_head_v_mla() const {
return is_mla() ? n_embd_head_v_mla_impl : n_embd_head_v;
return is_mla() ? n_embd_head_v_mla_impl : n_embd_head_v();
}
bool llama_hparams::has_kv(uint32_t il) const {

View File

@ -44,13 +44,20 @@ struct llama_hparams {
uint32_t n_embd;
uint32_t n_layer;
int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
uint32_t n_rot;
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
uint32_t n_expert = 0;
uint32_t n_expert_used = 0;
uint32_t n_rel_attn_bkts = 0;
// different head size for full_attention and SWA layers
uint32_t n_embd_head_k_full; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
uint32_t n_embd_head_v_full; // dimension of values (d_v) aka n_embd_head
uint32_t n_embd_head_k_swa;
uint32_t n_embd_head_v_swa;
// different RoPE dimensions for full_attention and SWA layers
uint32_t n_rot_full;
uint32_t n_rot_swa;
// note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
uint32_t n_embd_head_k_mla_impl = 0;
uint32_t n_embd_head_v_mla_impl = 0;
@ -247,12 +254,18 @@ struct llama_hparams {
uint32_t n_gqa(uint32_t il = 0) const;
uint32_t n_rot(uint32_t il = 0) const;
// dimension of main + auxiliary input embeddings
uint32_t n_embd_inp() const;
// dimension of output embeddings
uint32_t n_embd_out() const;
// dimension of key/value embeddings for each head (per layer)
uint32_t n_embd_head_k(uint32_t il = 0) const;
uint32_t n_embd_head_v(uint32_t il = 0) const;
// dimension of key embeddings across all k-v heads
uint32_t n_embd_k_gqa(uint32_t il = 0) const;

View File

@ -1033,8 +1033,8 @@ ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_k
const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
return ggml_view_4d(ctx, k,
hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns,
ggml_row_size(k->type, hparams.n_embd_head_k),
hparams.n_embd_head_k(il), hparams.n_head_kv(il), n_kv, ns,
ggml_row_size(k->type, hparams.n_embd_head_k(il)),
ggml_row_size(k->type, n_embd_k_gqa),
ggml_row_size(k->type, n_embd_k_gqa*kv_size),
ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);
@ -1056,8 +1056,8 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k
if (!v_trans) {
// note: v->nb[1] <= v->nb[2]
return ggml_view_4d(ctx, v,
hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
hparams.n_embd_head_v(il), hparams.n_head_kv(il), n_kv, ns,
ggml_row_size(v->type, hparams.n_embd_head_v(il)), // v->nb[1]
ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2]
ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
@ -1065,8 +1065,8 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k
// note: v->nb[1] > v->nb[2]
return ggml_view_4d(ctx, v,
n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v(il), ns,
ggml_row_size(v->type, kv_size*hparams.n_embd_head_v(il)), // v->nb[1]
ggml_row_size(v->type, kv_size), // v->nb[2]
ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
@ -1544,7 +1544,8 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
ggml_tensor * shift,
ggml_tensor * factors,
float freq_base,
float freq_scale) const {
float freq_scale,
uint32_t il) const {
const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
const auto & yarn_ext_factor = cparams.yarn_ext_factor;
@ -1552,7 +1553,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
const auto & yarn_beta_slow = cparams.yarn_beta_slow;
const auto & yarn_attn_factor = cparams.yarn_attn_factor;
const auto & n_rot = hparams.n_rot;
const auto & n_rot = hparams.n_rot(il);
const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
// @ngxson : this is a workaround
// for M-RoPE, we want to rotate the whole vector when doing KV shift
@ -1606,13 +1607,6 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
auto * ctx = res->get_ctx();
auto * gf = res->get_gf();
const auto & n_embd_head_k = hparams.n_embd_head_k;
//const auto & n_embd_head_v = hparams.n_embd_head_v;
const auto & n_rot = hparams.n_rot;
const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - n_rot : 0;
auto inp = std::make_unique<llm_graph_input_k_shift>(this);
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
@ -1626,6 +1620,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
const int64_t n_head_kv = hparams.n_head_kv(il);
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const auto n_rot = hparams.n_rot(il);
const auto n_embd_head_k = hparams.n_embd_head_k(il);
const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - n_rot : 0;
const float freq_base_l = model.get_rope_freq_base (cparams, il);
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
@ -1638,7 +1636,7 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
ggml_row_size(layer.k->type, n_embd_k_gqa),
ggml_row_size(layer.k->type, n_embd_nope));
ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, il);
ggml_build_forward_expand(gf, cur);
}

View File

@ -264,7 +264,8 @@ private:
ggml_tensor * shift,
ggml_tensor * factors,
float freq_base,
float freq_scale) const;
float freq_scale,
uint32_t il) const;
ggml_cgraph * build_graph_shift(
llm_graph_result * res,

View File

@ -918,7 +918,7 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
} break;
case GGML_OP_ROPE:
{
const int n_embd_head = hparams.n_embd_head_v;
const int n_embd_head = hparams.n_embd_head_v();
const int n_head = hparams.n_head();
ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, 512);
ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512);

View File

@ -186,8 +186,10 @@ void llama_model_saver::add_kv_from_model() {
add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, true);
add_kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias);
add_kv(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv);
add_kv(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k);
add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v);
add_kv(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k_full);
add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v_full);
add_kv(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa);
add_kv(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa);
add_kv(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
add_kv(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
@ -199,7 +201,8 @@ void llama_model_saver::add_kv_from_model() {
const float rope_scaling_factor = hparams.rope_freq_scale_train == 1.0f ? 0.0f : 1.0f/hparams.rope_freq_scale_train;
add_kv(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot);
add_kv(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot_full);
add_kv(LLM_KV_ROPE_DIMENSION_COUNT_SWA, hparams.n_rot_swa);
add_kv(LLM_KV_ROPE_FREQ_BASE, hparams.rope_freq_base_train);
// add_kv(LLM_KV_ROPE_SCALE_LINEAR, rope_scaling_factor); // old name
add_kv(LLM_KV_ROPE_SCALING_TYPE, llama_rope_scaling_type_name(hparams.rope_scaling_type_train));

View File

@ -459,26 +459,37 @@ void llama_model::load_hparams(llama_model_loader & ml) {
// gpt-neox n_rot = rotary_pct * (n_embd / n_head)
// gpt-j n_rot = rotary_dim
hparams.n_embd_head_k = hparams.n_embd / hparams.n_head();
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
hparams.n_embd_head_k_full = hparams.n_embd / hparams.n_head();
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k_full, false);
hparams.n_embd_head_v = hparams.n_embd / hparams.n_head();
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
hparams.n_embd_head_v_full = hparams.n_embd / hparams.n_head();
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v_full, false);
// sanity check for n_rot (optional)
hparams.n_rot = hparams.n_embd_head_k;
hparams.n_rot_full = hparams.n_embd_head_k_full;
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot_full, false);
if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON || arch == LLM_ARCH_LLAMA_EMBED) {
if (hparams.n_rot != hparams.n_embd_head_k) {
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
if (hparams.n_rot_full != hparams.n_embd_head_k_full) {
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot_full, hparams.n_embd_head_k_full));
}
}
} else {
hparams.n_rot = 0;
hparams.n_embd_head_k = 0;
hparams.n_embd_head_v = 0;
hparams.n_rot_full = 0;
hparams.n_embd_head_k_full = 0;
hparams.n_embd_head_v_full = 0;
}
// head size and n_rot for SWA layers
{
hparams.n_embd_head_k_swa = hparams.n_embd_head_k_full;
hparams.n_embd_head_v_swa = hparams.n_embd_head_v_full;
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa, false);
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa, false);
hparams.n_rot_swa = hparams.n_rot_full;
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT_SWA, hparams.n_rot_swa, false);
}
// for differentiating model types
@ -1114,10 +1125,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
break;
default: type = LLM_TYPE_UNKNOWN;
}
// Load attention parameters
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
} break;
case LLM_ARCH_PLAMO3:
{
@ -1212,7 +1219,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
// ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173
hparams.f_attention_scale = type == LLM_TYPE_27B
? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
: 1.0f / std::sqrt(float(hparams.n_embd_head_k));
: 1.0f / std::sqrt(float(hparams.n_embd_head_k()));
} break;
case LLM_ARCH_GEMMA3:
{
@ -1245,7 +1252,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
// ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289
hparams.f_attention_scale = type == LLM_TYPE_27B
? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
: 1.0f / std::sqrt(float(hparams.n_embd_head_k));
: 1.0f / std::sqrt(float(hparams.n_embd_head_k()));
} break;
case LLM_ARCH_GEMMA3N:
{
@ -1294,7 +1301,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
case 24: type = LLM_TYPE_0_3B; break;
default: type = LLM_TYPE_UNKNOWN;
}
hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k));
hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k()));
} break;
case LLM_ARCH_STARCODER2:
@ -2487,7 +2494,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl);
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl);
ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot);
ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
ml.get_key(LLM_KV_KDA_HEAD_DIM, hparams.n_embd_head_kda);
@ -2518,6 +2524,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
// full_attention layer only use half of the RoPE dimensions
hparams.n_rot_full = hparams.n_rot_full / 2;
// MoE + SWA parameters
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
@ -2661,13 +2670,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
const int64_t n_embd = hparams.n_embd;
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
const int64_t n_embd_head_k = hparams.n_embd_head_k;
const int64_t n_embd_head_v = hparams.n_embd_head_v;
const int64_t n_embd_head_k = hparams.n_embd_head_k();
const int64_t n_embd_head_v = hparams.n_embd_head_v();
const int64_t n_ff = hparams.n_ff();
const int64_t n_embd_gqa = n_embd_v_gqa;
const int64_t n_vocab = vocab.n_tokens();
const int64_t n_token_types = vocab.n_token_types();
const int64_t n_rot = hparams.n_rot;
const int64_t n_rot = hparams.n_rot();
const int64_t n_expert = hparams.n_expert;
const int64_t n_expert_used = hparams.n_expert_used;
const int64_t n_ctx_train = hparams.n_ctx_train;
@ -2967,8 +2976,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
} break;
case LLM_ARCH_MINICPM3:
{
const int64_t n_embd_head_qk_rope = hparams.n_rot;
const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
const int64_t n_embd_head_qk_rope = hparams.n_rot();
const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k() - hparams.n_rot();
const int64_t q_lora_rank = hparams.n_lora_q;
const int64_t kv_lora_rank = hparams.n_lora_kv;
@ -3840,8 +3849,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
// attention parameters
const uint32_t qk_dim = hparams.n_embd_head_k;
const uint32_t v_dim = hparams.n_embd_head_v;
const uint32_t qk_dim = hparams.n_embd_head_k();
const uint32_t v_dim = hparams.n_embd_head_v();
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@ -3901,8 +3910,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
} break;
case LLM_ARCH_PLAMO3:
{
const int64_t head_dim_q = hparams.n_embd_head_k;
const int64_t head_dim_v = hparams.n_embd_head_v;
const int64_t head_dim_q = hparams.n_embd_head_k();
const int64_t head_dim_v = hparams.n_embd_head_v();
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@ -4649,7 +4658,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
} break;
case LLM_ARCH_SEED_OSS:
{
const uint32_t head_dim = hparams.n_embd_head_k;
const uint32_t head_dim = hparams.n_embd_head_k();
const int64_t n_qo_dim = n_head * head_dim;
const int64_t n_kv_dim = n_head_kv * head_dim;
@ -4878,7 +4887,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla();
const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla();
const int64_t n_embd_head_qk_rope = hparams.n_rot;
const int64_t n_embd_head_qk_rope = hparams.n_rot();
const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope;
GGML_ASSERT(n_embd_head_qk_nope >= 1);
@ -4957,8 +4966,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
} break;
case LLM_ARCH_PLM:
{
const int64_t n_embd_head_qk_rope = hparams.n_rot;
const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
const int64_t n_embd_head_qk_rope = hparams.n_rot();
const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k() - hparams.n_rot();
const int64_t kv_lora_rank = hparams.n_lora_kv;
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@ -5396,7 +5405,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla();
const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla();
const int64_t n_embd_head_qk_rope = hparams.n_rot;
const int64_t n_embd_head_qk_rope = hparams.n_rot();
const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope;
const int64_t q_lora_rank = hparams.n_lora_q;
@ -5680,7 +5689,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
const int64_t n_expert = hparams.n_expert;
const int64_t n_expert_used = hparams.n_expert_used;
const int64_t n_ff_shexp = hparams.n_ff_shexp > 0 ? hparams.n_ff_shexp : n_ff_exp;
const int64_t head_dim = hparams.n_embd_head_k;
const int64_t head_dim = hparams.n_embd_head_k();
const int64_t n_qo_dim = n_head * head_dim;
const int64_t n_kv_dim = n_head_kv * head_dim;
@ -6968,7 +6977,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
// Kimi: qk_rope_head_dim = 64 (actual RoPE dimension for MLA)
// Note: hparams.n_rot may be 72 (from conversion) but actual is 64
const int64_t qk_rope_head_dim = hparams.n_rot; // From config: qk_rope_head_dim
const int64_t qk_rope_head_dim = hparams.n_rot(); // From config: qk_rope_head_dim
layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + qk_rope_head_dim}, 0);
// Support Legacy GGUFs that don't split wkv_b (MLA KV cache disabled)
layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i),
@ -7339,7 +7348,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
// ("rope_freqs.weight") and ggml uses only the first (n_rot_l/2) entries per layer.
uint32_t n_rot_max = 0;
for (int i = 0; i < n_layer; ++i) {
n_rot_max = std::max(n_rot_max, hparams.n_rot);
n_rot_max = std::max(n_rot_max, hparams.n_rot(i));
}
if (n_rot_max == 0) {
n_rot_max = n_rot;
@ -7674,11 +7683,11 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer);
LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer).c_str());
LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot);
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot_full);
LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any());
LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v);
LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k_full);
LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v_full);
LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str());
LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str());
LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str());
@ -7702,6 +7711,9 @@ void llama_model::print_info() const {
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
LLAMA_LOG_INFO("%s: freq_base_swa = %.1f\n", __func__, hparams.rope_freq_base_train_swa);
LLAMA_LOG_INFO("%s: freq_scale_swa = %g\n", __func__, hparams.rope_freq_scale_train_swa);
LLAMA_LOG_INFO("%s: n_embd_head_k_swa = %u\n", __func__, hparams.n_embd_head_k_swa);
LLAMA_LOG_INFO("%s: n_embd_head_v_swa = %u\n", __func__, hparams.n_embd_head_v_swa);
LLAMA_LOG_INFO("%s: n_rot_swa = %u\n", __func__, hparams.n_rot_swa);
}
LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn);
LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul);

File diff suppressed because it is too large Load Diff

View File

@ -1,8 +1,8 @@
#include "models.h"
llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
ggml_tensor * cur;
ggml_tensor * inpL;

View File

@ -3,10 +3,10 @@
llm_build_apertus::llm_build_apertus(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;

View File

@ -2,10 +2,10 @@
llm_build_arcee::llm_build_arcee(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;

View File

@ -1,10 +1,10 @@
#include "models.h"
llm_build_arctic::llm_build_arctic(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;

View File

@ -2,10 +2,10 @@
llm_build_baichuan::llm_build_baichuan(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;

View File

@ -2,10 +2,10 @@
llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
ggml_tensor * cur;
ggml_tensor * inpL;

View File

@ -1,10 +1,10 @@
#include "models.h"
llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
ggml_tensor * cur;
ggml_tensor * inpL;

View File

@ -2,9 +2,9 @@
llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
ggml_tensor * cur;
ggml_tensor * inpL;

View File

@ -1,10 +1,10 @@
#include "models.h"
llm_build_bloom::llm_build_bloom(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
ggml_tensor * cur;
ggml_tensor * inpL;

View File

@ -3,10 +3,10 @@
#include <float.h>
llm_build_chameleon::llm_build_chameleon(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;

View File

@ -2,10 +2,10 @@
llm_build_chatglm::llm_build_chatglm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
ggml_tensor * cur;
ggml_tensor * inpL;

View File

@ -1,11 +1,11 @@
#include "models.h"
llm_build_codeshell::llm_build_codeshell(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;

View File

@ -2,11 +2,11 @@
llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
const float kq_scale = 1.0f / sqrtf(float(n_embd_head));
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * inpL;
ggml_tensor * cur;

View File

@ -1,9 +1,9 @@
#include "models.h"
llm_build_cohere2_iswa::llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
const float f_logit_scale = hparams.f_logit_scale;

View File

@ -4,9 +4,9 @@
llm_build_command_r::llm_build_command_r(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
const float f_logit_scale = hparams.f_logit_scale;

View File

@ -1,11 +1,11 @@
#include "models.h"
llm_build_dbrx::llm_build_dbrx(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;

View File

@ -3,10 +3,10 @@
llm_build_deci::llm_build_deci(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;

View File

@ -2,10 +2,10 @@
llm_build_deepseek::llm_build_deepseek(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;

View File

@ -8,7 +8,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
const int64_t n_embd_head_k = hparams.n_embd_head_k_mla();
const int64_t n_embd_head_v = hparams.n_embd_head_v_mla();
const int64_t n_embd_head_qk_rope = hparams.n_rot;
const int64_t n_embd_head_qk_rope = hparams.n_rot();
const int64_t n_embd_head_qk_nope = n_embd_head_k - n_embd_head_qk_rope;
const uint32_t kv_lora_rank = hparams.n_lora_kv;

View File

@ -2,10 +2,10 @@
llm_build_dots1::llm_build_dots1(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;

View File

@ -5,10 +5,10 @@
llm_build_dream::llm_build_dream(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
//copied from qwen2
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;

View File

@ -2,10 +2,10 @@
llm_build_ernie4_5_moe::llm_build_ernie4_5_moe(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;

View File

@ -2,10 +2,10 @@
llm_build_ernie4_5::llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;

View File

@ -1,9 +1,9 @@
#include "models.h"
llm_build_eurobert::llm_build_eurobert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
ggml_tensor * cur;
ggml_tensor * inpL;

View File

@ -2,10 +2,10 @@
llm_build_exaone_moe::llm_build_exaone_moe(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_k;
const int64_t n_embd_head = hparams.n_embd_head_k();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_v);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_v());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;

View File

@ -4,10 +4,10 @@
llm_build_exaone::llm_build_exaone(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;

View File

@ -4,10 +4,10 @@
template <bool iswa>
llm_build_exaone4<iswa>::llm_build_exaone4(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_k;
const int64_t n_embd_head = hparams.n_embd_head_k();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_v);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_v());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;

View File

@ -2,7 +2,7 @@
llm_build_falcon_h1::llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params) :
llm_build_mamba_base(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
ggml_tensor * cur;
ggml_tensor * inpL;

View File

@ -2,11 +2,11 @@
llm_build_falcon::llm_build_falcon(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;

View File

@ -2,7 +2,7 @@
llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_k;
const int64_t n_embd_head = hparams.n_embd_head_k();
ggml_tensor * cur;
ggml_tensor * inpL;

View File

@ -2,7 +2,7 @@
llm_build_gemma::llm_build_gemma(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
ggml_tensor * cur;
ggml_tensor * inpL;

View File

@ -1,7 +1,7 @@
#include "models.h"
llm_build_gemma2_iswa::llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_k;
const int64_t n_embd_head = hparams.n_embd_head_k();
ggml_tensor * cur;
ggml_tensor * inpL;

View File

@ -2,7 +2,7 @@
template <bool iswa>
llm_build_gemma3<iswa>::llm_build_gemma3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_k;
const int64_t n_embd_head = hparams.n_embd_head_k();
ggml_tensor * cur;
ggml_tensor * inpL;

Some files were not shown because too many files have changed in this diff Show More