Merge branch 'master' into c_api
This commit is contained in:
commit
d3f9d29589
|
|
@ -469,6 +469,7 @@ jobs:
|
|||
cd build
|
||||
export GGML_VK_VISIBLE_DEVICES=0
|
||||
export GGML_VK_DISABLE_F16=1
|
||||
export GGML_VK_DISABLE_COOPMAT=1
|
||||
# This is using llvmpipe and runs slower than other backends
|
||||
ctest -L main --verbose --timeout 4800
|
||||
|
||||
|
|
|
|||
|
|
@ -81,6 +81,8 @@ add_library(${TARGET} STATIC
|
|||
preset.cpp
|
||||
preset.h
|
||||
regex-partial.cpp
|
||||
reasoning-budget.cpp
|
||||
reasoning-budget.h
|
||||
regex-partial.h
|
||||
sampling.cpp
|
||||
sampling.h
|
||||
|
|
|
|||
|
|
@ -2427,11 +2427,11 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
);
|
||||
}
|
||||
if (split_arg.size() == 1) {
|
||||
std::fill(params.fit_params_target.begin(), params.fit_params_target.end(), std::stoul(split_arg[0]) * 1024*1024);
|
||||
std::fill(params.fit_params_target.begin(), params.fit_params_target.end(), std::stoull(split_arg[0]) * 1024*1024);
|
||||
return;
|
||||
}
|
||||
for (size_t i = 0; i < split_arg.size(); i++) {
|
||||
params.fit_params_target[i] = std::stoul(split_arg[i]) * 1024*1024;
|
||||
params.fit_params_target[i] = std::stoull(split_arg[i]) * 1024*1024;
|
||||
}
|
||||
}
|
||||
).set_env("LLAMA_ARG_FIT_TARGET"));
|
||||
|
|
@ -2913,6 +2913,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
[](common_params & params, const std::string & value) {
|
||||
auto parsed = json::parse(value);
|
||||
for (const auto & item : parsed.items()) {
|
||||
if (item.key() == "enable_thinking") {
|
||||
LOG_WRN("Setting 'enable_thinking' via --chat-template-kwargs is deprecated. "
|
||||
"Use --reasoning on / --reasoning off instead.\n");
|
||||
}
|
||||
params.default_template_kwargs[item.key()] = item.value().dump();
|
||||
}
|
||||
}
|
||||
|
|
@ -3048,14 +3052,39 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
params.reasoning_format = common_reasoning_format_from_name(value);
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK"));
|
||||
add_opt(common_arg(
|
||||
{"-rea", "--reasoning"}, "[on|off|auto]",
|
||||
"Use reasoning/thinking in the chat ('on', 'off', or 'auto', default: 'auto' (detect from template))",
|
||||
[](common_params & params, const std::string & value) {
|
||||
if (is_truthy(value)) {
|
||||
params.enable_reasoning = 1;
|
||||
params.default_template_kwargs["enable_thinking"] = "true";
|
||||
} else if (is_falsey(value)) {
|
||||
params.enable_reasoning = 0;
|
||||
params.default_template_kwargs["enable_thinking"] = "false";
|
||||
} else if (is_autoy(value)) {
|
||||
params.enable_reasoning = -1;
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
string_format("error: unknown value for --reasoning: '%s'\n", value.c_str()));
|
||||
}
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_REASONING"));
|
||||
add_opt(common_arg(
|
||||
{"--reasoning-budget"}, "N",
|
||||
"controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)",
|
||||
"token budget for thinking: -1 for unrestricted, 0 for immediate end, N>0 for token budget (default: -1)",
|
||||
[](common_params & params, int value) {
|
||||
if (value != 0 && value != -1) { throw std::invalid_argument("invalid value"); }
|
||||
if (value < -1) { throw std::invalid_argument("invalid value"); }
|
||||
params.reasoning_budget = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET"));
|
||||
add_opt(common_arg(
|
||||
{"--reasoning-budget-message"}, "MESSAGE",
|
||||
"message injected before the end-of-thinking tag when reasoning budget is exhausted (default: none)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.reasoning_budget_message = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_MESSAGE"));
|
||||
add_opt(common_arg(
|
||||
{"--chat-template"}, "JINJA_TEMPLATE",
|
||||
string_format(
|
||||
|
|
|
|||
|
|
@ -135,7 +135,9 @@ common_peg_parser analyze_reasoning::build_parser(parser_build_context & ctx) co
|
|||
if (thinking_forced_open || thinking_forced_closed) {
|
||||
// Thinking is forced open OR forced closed with enable_thinking=true
|
||||
// In both cases, expect only the closing tag (opening was in template)
|
||||
return p.reasoning(p.until(end)) + end;
|
||||
// However, since we might have incorrectly detected the open/close pattern,
|
||||
// we admit an optional starting marker
|
||||
return p.optional(p.literal(start)) + p.reasoning(p.until(end)) + end;
|
||||
}
|
||||
if (mode == reasoning_mode::TAG_BASED || mode == reasoning_mode::TOOLS_ONLY) {
|
||||
// Standard tag-based reasoning OR tools-only mode (reasoning appears with tools)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
using ordered_json = nlohmann::ordered_json;
|
||||
|
||||
static std::string_view trim_trailing_space(std::string_view sv, int max = -1) {
|
||||
int count = 0;
|
||||
|
|
@ -68,7 +68,7 @@ static int json_brace_depth(const std::string & s) {
|
|||
|
||||
// JSON-escape a string and return the inner content (without surrounding quotes).
|
||||
static std::string escape_json_string_inner(const std::string & s) {
|
||||
std::string escaped = json(s).dump();
|
||||
std::string escaped = ordered_json(s).dump();
|
||||
if (escaped.size() >= 2 && escaped.front() == '"' && escaped.back() == '"') {
|
||||
return escaped.substr(1, escaped.size() - 2);
|
||||
}
|
||||
|
|
@ -309,7 +309,7 @@ void common_chat_peg_mapper::map(const common_peg_ast_node & node) {
|
|||
if (arg_count > 0) {
|
||||
arg_entry = ",";
|
||||
}
|
||||
arg_entry += json(trim(node.text)).dump() + ":";
|
||||
arg_entry += ordered_json(trim(node.text)).dump() + ":";
|
||||
++arg_count;
|
||||
|
||||
auto & target = args_target();
|
||||
|
|
@ -343,7 +343,7 @@ void common_chat_peg_mapper::map(const common_peg_ast_node & node) {
|
|||
|
||||
// Try to parse as JSON value (number, bool, null, object, array)
|
||||
try {
|
||||
json parsed = json::parse(value_content);
|
||||
ordered_json parsed = ordered_json::parse(value_content);
|
||||
if (parsed.is_string()) {
|
||||
// Don't add closing quote yet (added by arg_close) for monotonic streaming
|
||||
std::string escaped = parsed.dump();
|
||||
|
|
@ -408,7 +408,7 @@ void common_chat_peg_mapper::map(const common_peg_ast_node & node) {
|
|||
|
||||
common_peg_parser common_chat_peg_builder::standard_constructed_tools(
|
||||
const std::map<std::string, std::string> & markers,
|
||||
const nlohmann::json & tools,
|
||||
const ordered_json & tools,
|
||||
bool parallel_tool_calls,
|
||||
bool force_tool_calls) {
|
||||
if (!tools.is_array() || tools.empty()) {
|
||||
|
|
@ -439,7 +439,7 @@ common_peg_parser common_chat_peg_builder::standard_constructed_tools(
|
|||
}
|
||||
const auto & function = tool_def.at("function");
|
||||
std::string name = function.at("name");
|
||||
nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object();
|
||||
ordered_json params = function.contains("parameters") ? function.at("parameters") : ordered_json::object();
|
||||
|
||||
// Build argument parsers
|
||||
auto args = eps();
|
||||
|
|
@ -479,8 +479,8 @@ common_peg_parser common_chat_peg_builder::standard_constructed_tools(
|
|||
// Python-style tool calls: name(arg1="value1", arg2=123)
|
||||
// Used only by LFM2 for now, so we don't merge it into autoparser
|
||||
common_peg_parser common_chat_peg_builder::python_style_tool_calls(
|
||||
const nlohmann::json & tools,
|
||||
bool parallel_tool_calls) {
|
||||
const ordered_json & tools,
|
||||
bool parallel_tool_calls) {
|
||||
if (!tools.is_array() || tools.empty()) {
|
||||
return eps();
|
||||
}
|
||||
|
|
@ -493,7 +493,7 @@ common_peg_parser common_chat_peg_builder::python_style_tool_calls(
|
|||
}
|
||||
const auto & function = tool_def.at("function");
|
||||
std::string name = function.at("name");
|
||||
nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object();
|
||||
ordered_json params = function.contains("parameters") ? function.at("parameters") : ordered_json::object();
|
||||
|
||||
auto args = eps();
|
||||
if (params.contains("properties") && !params["properties"].empty()) {
|
||||
|
|
@ -555,11 +555,11 @@ static std::pair<std::string, std::string> parse_key_spec(const std::string & ke
|
|||
|
||||
// Mode 1: function_is_key — parse {"function_name": {...}}
|
||||
common_peg_parser common_chat_peg_builder::build_json_tools_function_is_key(
|
||||
const nlohmann::json & tools,
|
||||
const std::string & args_key,
|
||||
const std::string & effective_args_key,
|
||||
const std::string & call_id_key,
|
||||
const std::string & gen_call_id_key) {
|
||||
const ordered_json & tools,
|
||||
const std::string & args_key,
|
||||
const std::string & effective_args_key,
|
||||
const std::string & call_id_key,
|
||||
const std::string & gen_call_id_key) {
|
||||
|
||||
auto tool_choices = choice();
|
||||
|
||||
|
|
@ -569,7 +569,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_function_is_key(
|
|||
}
|
||||
const auto & function = tool_def.at("function");
|
||||
std::string name = function.at("name");
|
||||
nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object();
|
||||
ordered_json params = function.contains("parameters") ? function.at("parameters") : ordered_json::object();
|
||||
|
||||
// Build inner object fields
|
||||
std::vector<common_peg_parser> inner_fields;
|
||||
|
|
@ -634,11 +634,11 @@ common_peg_parser common_chat_peg_builder::build_json_tools_function_is_key(
|
|||
|
||||
// Mode 2: Nested keys (dot notation like "function.name")
|
||||
common_peg_parser common_chat_peg_builder::build_json_tools_nested_keys(
|
||||
const nlohmann::json & tools,
|
||||
const std::string & effective_name_key,
|
||||
const std::string & effective_args_key,
|
||||
const std::string & call_id_key,
|
||||
const std::string & gen_call_id_key) {
|
||||
const ordered_json & tools,
|
||||
const std::string & effective_name_key,
|
||||
const std::string & effective_args_key,
|
||||
const std::string & call_id_key,
|
||||
const std::string & gen_call_id_key) {
|
||||
|
||||
auto tool_choices = choice();
|
||||
|
||||
|
|
@ -655,7 +655,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_nested_keys(
|
|||
}
|
||||
const auto & function = tool_def.at("function");
|
||||
std::string name = function.at("name");
|
||||
nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object();
|
||||
ordered_json params = function.contains("parameters") ? function.at("parameters") : ordered_json::object();
|
||||
|
||||
auto nested_name = literal("\"" + nested_name_field + "\"") + space() + literal(":") + space() +
|
||||
literal("\"") + tool_name(literal(name)) + literal("\"");
|
||||
|
|
@ -706,7 +706,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_nested_keys(
|
|||
|
||||
// Mode 3: Flat keys with optional ID fields and parameter ordering
|
||||
common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys(
|
||||
const nlohmann::json & tools,
|
||||
const ordered_json & tools,
|
||||
const std::string & effective_name_key,
|
||||
const std::string & effective_args_key,
|
||||
const std::string & call_id_key,
|
||||
|
|
@ -723,7 +723,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys(
|
|||
}
|
||||
const auto & function = tool_def.at("function");
|
||||
std::string name = function.at("name");
|
||||
nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object();
|
||||
ordered_json params = function.contains("parameters") ? function.at("parameters") : ordered_json::object();
|
||||
|
||||
auto tool_name_ = name_key_parser + space() + literal(":") + space() +
|
||||
literal("\"") + tool_name(literal(name)) + literal("\"");
|
||||
|
|
@ -791,7 +791,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys(
|
|||
common_peg_parser common_chat_peg_builder::standard_json_tools(
|
||||
const std::string & section_start,
|
||||
const std::string & section_end,
|
||||
const nlohmann::json & tools,
|
||||
const ordered_json & tools,
|
||||
bool parallel_tool_calls,
|
||||
bool force_tool_calls,
|
||||
const std::string & name_key,
|
||||
|
|
|
|||
|
|
@ -94,7 +94,7 @@ class common_chat_peg_builder : public common_peg_parser_builder {
|
|||
// parameters_order: order in which JSON fields should be parsed
|
||||
common_peg_parser standard_json_tools(const std::string & section_start,
|
||||
const std::string & section_end,
|
||||
const nlohmann::json & tools,
|
||||
const nlohmann::ordered_json & tools,
|
||||
bool parallel_tool_calls,
|
||||
bool force_tool_calls,
|
||||
const std::string & name_key = "",
|
||||
|
|
@ -108,30 +108,30 @@ class common_chat_peg_builder : public common_peg_parser_builder {
|
|||
// Legacy-compatible helper for building XML/tagged style tool calls
|
||||
// Used by tests and manual parsers
|
||||
common_peg_parser standard_constructed_tools(const std::map<std::string, std::string> & markers,
|
||||
const nlohmann::json & tools,
|
||||
const nlohmann::ordered_json & tools,
|
||||
bool parallel_tool_calls,
|
||||
bool force_tool_calls);
|
||||
|
||||
// Helper for Python-style function call format: name(arg1="value1", arg2=123)
|
||||
// Used by LFM2 and similar templates
|
||||
common_peg_parser python_style_tool_calls(const nlohmann::json & tools,
|
||||
bool parallel_tool_calls);
|
||||
common_peg_parser python_style_tool_calls(const nlohmann::ordered_json & tools,
|
||||
bool parallel_tool_calls);
|
||||
|
||||
private:
|
||||
// Implementation helpers for standard_json_tools — one per JSON tool call layout mode
|
||||
common_peg_parser build_json_tools_function_is_key(const nlohmann::json & tools,
|
||||
const std::string & args_key,
|
||||
const std::string & effective_args_key,
|
||||
const std::string & call_id_key,
|
||||
const std::string & gen_call_id_key);
|
||||
common_peg_parser build_json_tools_function_is_key(const nlohmann::ordered_json & tools,
|
||||
const std::string & args_key,
|
||||
const std::string & effective_args_key,
|
||||
const std::string & call_id_key,
|
||||
const std::string & gen_call_id_key);
|
||||
|
||||
common_peg_parser build_json_tools_nested_keys(const nlohmann::json & tools,
|
||||
const std::string & effective_name_key,
|
||||
const std::string & effective_args_key,
|
||||
const std::string & call_id_key,
|
||||
const std::string & gen_call_id_key);
|
||||
common_peg_parser build_json_tools_nested_keys(const nlohmann::ordered_json & tools,
|
||||
const std::string & effective_name_key,
|
||||
const std::string & effective_args_key,
|
||||
const std::string & call_id_key,
|
||||
const std::string & gen_call_id_key);
|
||||
|
||||
common_peg_parser build_json_tools_flat_keys(const nlohmann::json & tools,
|
||||
common_peg_parser build_json_tools_flat_keys(const nlohmann::ordered_json & tools,
|
||||
const std::string & effective_name_key,
|
||||
const std::string & effective_args_key,
|
||||
const std::string & call_id_key,
|
||||
|
|
|
|||
|
|
@ -857,7 +857,9 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
|
|||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
auto include_grammar = true;
|
||||
|
||||
data.supports_thinking = true;
|
||||
data.supports_thinking = true;
|
||||
data.thinking_start_tag = "[THINK]";
|
||||
data.thinking_end_tag = "[/THINK]";
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override = */ adjusted_messages);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.preserved_tokens = {
|
||||
|
|
@ -1165,9 +1167,11 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
|||
const autoparser::templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.thinking_start_tag = "<think>";
|
||||
data.thinking_end_tag = "</think>";
|
||||
data.preserved_tokens = {
|
||||
"<|tool_calls_section_begin|>",
|
||||
"<|tool_calls_section_end|>",
|
||||
|
|
@ -1527,6 +1531,16 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
|||
autoparser.analyze_template(tmpl);
|
||||
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
|
||||
auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE;
|
||||
if (auto_params.supports_thinking) {
|
||||
auto_params.thinking_start_tag = autoparser.reasoning.start;
|
||||
auto_params.thinking_end_tag = autoparser.reasoning.end;
|
||||
// FORCED_OPEN and FORCED_CLOSED both put <think> in the generation prompt
|
||||
// (FORCED_CLOSED forces empty <think></think> when thinking is disabled,
|
||||
// but forces <think> open when thinking is enabled)
|
||||
auto_params.thinking_forced_open =
|
||||
autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_OPEN ||
|
||||
autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_CLOSED;
|
||||
}
|
||||
return auto_params;
|
||||
} catch (const std::exception & e) {
|
||||
throw std::invalid_argument(std::string("Unable to generate parser for this template. Automatic parser generation failed: ") + e.what());
|
||||
|
|
@ -1620,8 +1634,8 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
|
|||
build_chat_peg_parser([](common_chat_peg_builder & p) { return p.content(p.rest()) + p.end(); }) :
|
||||
src_parser;
|
||||
|
||||
if (src_parser.empty()) {
|
||||
LOG_WRN("No parser definition detected, assuming pure content parser.");
|
||||
if (src_parser.empty()) {
|
||||
LOG_DBG("No parser definition detected, assuming pure content parser.");
|
||||
}
|
||||
|
||||
LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), input.c_str());
|
||||
|
|
|
|||
|
|
@ -213,6 +213,8 @@ struct common_chat_params {
|
|||
bool grammar_lazy = false;
|
||||
bool thinking_forced_open = false;
|
||||
bool supports_thinking = false;
|
||||
std::string thinking_start_tag; // e.g., "<think>"
|
||||
std::string thinking_end_tag; // e.g., "</think>"
|
||||
std::vector<common_grammar_trigger> grammar_triggers;
|
||||
std::vector<std::string> preserved_tokens;
|
||||
std::vector<std::string> additional_stops;
|
||||
|
|
|
|||
|
|
@ -235,6 +235,14 @@ struct common_params_sampling {
|
|||
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
|
||||
|
||||
// reasoning budget sampler parameters
|
||||
// these are populated by the server/CLI based on chat template params
|
||||
int32_t reasoning_budget_tokens = -1; // -1 = disabled, >= 0 = token budget
|
||||
bool reasoning_budget_activate_immediately = false;
|
||||
std::vector<llama_token> reasoning_budget_start; // start tag token sequence
|
||||
std::vector<llama_token> reasoning_budget_end; // end tag token sequence
|
||||
std::vector<llama_token> reasoning_budget_forced; // forced sequence (message + end tag)
|
||||
|
||||
bool backend_sampling = false;
|
||||
|
||||
bool has_logit_bias() const {
|
||||
|
|
@ -536,7 +544,9 @@ struct common_params {
|
|||
bool use_jinja = true; // NOLINT
|
||||
bool enable_chat_template = true;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
int enable_reasoning = -1; // -1 = auto, 0 = disable, 1 = enable
|
||||
int reasoning_budget = -1;
|
||||
std::string reasoning_budget_message; // message injected before end tag when budget exhausted
|
||||
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
|
||||
int sleep_idle_seconds = -1; // if >0, server will sleep after this many seconds of idle time
|
||||
|
||||
|
|
|
|||
|
|
@ -790,7 +790,7 @@ public:
|
|||
} else if (target.is_array()) {
|
||||
size_t sel_index;
|
||||
try {
|
||||
sel_index = std::stoul(sel);
|
||||
sel_index = std::stoull(sel);
|
||||
} catch (const std::invalid_argument & e) {
|
||||
sel_index = target.size();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,219 @@
|
|||
#include "reasoning-budget.h"
|
||||
#include "common.h"
|
||||
#include "unicode.h"
|
||||
|
||||
#include "log.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
struct token_matcher {
|
||||
std::vector<llama_token> tokens;
|
||||
size_t pos = 0;
|
||||
|
||||
bool advance(llama_token token) {
|
||||
if (tokens.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (token == tokens[pos]) {
|
||||
pos++;
|
||||
if (pos >= tokens.size()) {
|
||||
pos = 0;
|
||||
return true;
|
||||
}
|
||||
} else {
|
||||
pos = 0;
|
||||
if (token == tokens[0]) {
|
||||
pos = 1;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void reset() { pos = 0; }
|
||||
};
|
||||
|
||||
struct common_reasoning_budget_ctx {
|
||||
const llama_vocab * vocab;
|
||||
|
||||
token_matcher start_matcher;
|
||||
token_matcher end_matcher;
|
||||
std::vector<llama_token> forced_tokens;
|
||||
|
||||
int32_t budget; // maximum tokens in reasoning block
|
||||
int32_t remaining; // tokens remaining in budget
|
||||
|
||||
common_reasoning_budget_state state;
|
||||
|
||||
// for forcing
|
||||
size_t force_pos; // next position in forced_tokens to force
|
||||
};
|
||||
|
||||
static const char * common_reasoning_budget_name(const struct llama_sampler * /*smpl*/) {
|
||||
return "reasoning-budget";
|
||||
}
|
||||
|
||||
static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_token token) {
|
||||
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
|
||||
|
||||
switch (ctx->state) {
|
||||
case REASONING_BUDGET_IDLE:
|
||||
{
|
||||
if (ctx->start_matcher.advance(token)) {
|
||||
ctx->state = REASONING_BUDGET_COUNTING;
|
||||
ctx->remaining = ctx->budget;
|
||||
LOG_INF("reasoning-budget: activated, budget=%d tokens\n", ctx->budget);
|
||||
|
||||
if (ctx->remaining <= 0) {
|
||||
ctx->state = REASONING_BUDGET_FORCING;
|
||||
ctx->force_pos = 0;
|
||||
LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
case REASONING_BUDGET_COUNTING:
|
||||
case REASONING_BUDGET_WAITING_UTF8:
|
||||
{
|
||||
if (ctx->end_matcher.advance(token)) {
|
||||
ctx->state = REASONING_BUDGET_DONE;
|
||||
LOG_INF("reasoning-budget: deactivated (natural end)\n");
|
||||
break;
|
||||
}
|
||||
|
||||
bool utf8_complete = true;
|
||||
if (ctx->vocab != nullptr) {
|
||||
const std::string piece = common_token_to_piece(ctx->vocab, token, false);
|
||||
utf8_complete = common_utf8_is_complete(piece);
|
||||
}
|
||||
|
||||
if (ctx->state == REASONING_BUDGET_WAITING_UTF8) {
|
||||
if (utf8_complete) {
|
||||
ctx->state = REASONING_BUDGET_FORCING;
|
||||
ctx->force_pos = 0;
|
||||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: UTF-8 complete, now forcing end sequence\n");
|
||||
}
|
||||
} else if (ctx->state == REASONING_BUDGET_COUNTING) {
|
||||
ctx->remaining--;
|
||||
if (ctx->remaining <= 0) {
|
||||
if (utf8_complete) {
|
||||
ctx->state = REASONING_BUDGET_FORCING;
|
||||
ctx->force_pos = 0;
|
||||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: budget exhausted, forcing end sequence\n");
|
||||
} else {
|
||||
ctx->state = REASONING_BUDGET_WAITING_UTF8;
|
||||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: budget exhausted, waiting for UTF-8 completion\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
case REASONING_BUDGET_FORCING:
|
||||
// force_pos is advanced in apply(), not here.
|
||||
// This ensures the first forced token isn't skipped when the sampler
|
||||
// is initialized directly in FORCING state (e.g. COUNTING + budget=0)
|
||||
break;
|
||||
case REASONING_BUDGET_DONE:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
|
||||
|
||||
if (ctx->state != REASONING_BUDGET_FORCING) {
|
||||
// passthrough — don't modify logits
|
||||
return;
|
||||
}
|
||||
|
||||
if (ctx->force_pos >= ctx->forced_tokens.size()) {
|
||||
return;
|
||||
}
|
||||
|
||||
const llama_token forced = ctx->forced_tokens[ctx->force_pos];
|
||||
|
||||
// set all logits to -inf except the forced token
|
||||
for (size_t i = 0; i < cur_p->size; i++) {
|
||||
if (cur_p->data[i].id != forced) {
|
||||
cur_p->data[i].logit = -INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
// advance to next forced token (done here rather than in accept so that
|
||||
// the first forced token isn't skipped when starting in FORCING state)
|
||||
ctx->force_pos++;
|
||||
if (ctx->force_pos >= ctx->forced_tokens.size()) {
|
||||
ctx->state = REASONING_BUDGET_DONE;
|
||||
LOG_INF("reasoning-budget: forced sequence complete, done\n");
|
||||
}
|
||||
}
|
||||
|
||||
static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
|
||||
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
|
||||
ctx->state = REASONING_BUDGET_IDLE;
|
||||
ctx->remaining = ctx->budget;
|
||||
ctx->start_matcher.reset();
|
||||
ctx->end_matcher.reset();
|
||||
ctx->force_pos = 0;
|
||||
}
|
||||
|
||||
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx;
|
||||
return common_reasoning_budget_init(
|
||||
ctx->vocab,
|
||||
ctx->start_matcher.tokens,
|
||||
ctx->end_matcher.tokens,
|
||||
ctx->forced_tokens,
|
||||
ctx->budget,
|
||||
ctx->state);
|
||||
}
|
||||
|
||||
static void common_reasoning_budget_free(struct llama_sampler * smpl) {
|
||||
delete (common_reasoning_budget_ctx *) smpl->ctx;
|
||||
}
|
||||
|
||||
static struct llama_sampler_i common_reasoning_budget_i = {
|
||||
/* .name = */ common_reasoning_budget_name,
|
||||
/* .accept = */ common_reasoning_budget_accept,
|
||||
/* .apply = */ common_reasoning_budget_apply,
|
||||
/* .reset = */ common_reasoning_budget_reset,
|
||||
/* .clone = */ common_reasoning_budget_clone,
|
||||
/* .free = */ common_reasoning_budget_free,
|
||||
/* .backend_init = */ nullptr,
|
||||
/* .backend_accept = */ nullptr,
|
||||
/* .backend_apply = */ nullptr,
|
||||
/* .backend_set_input = */ nullptr,
|
||||
};
|
||||
|
||||
struct llama_sampler * common_reasoning_budget_init(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget,
|
||||
common_reasoning_budget_state initial_state) {
|
||||
// promote COUNTING with budget <= 0 to FORCING
|
||||
if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) {
|
||||
initial_state = REASONING_BUDGET_FORCING;
|
||||
}
|
||||
|
||||
return llama_sampler_init(
|
||||
/* .iface = */ &common_reasoning_budget_i,
|
||||
/* .ctx = */ new common_reasoning_budget_ctx {
|
||||
/* .vocab = */ vocab,
|
||||
/* .start_matcher = */ { start_tokens, 0 },
|
||||
/* .end_matcher = */ { end_tokens, 0 },
|
||||
/* .forced_tokens = */ forced_tokens,
|
||||
/* .budget = */ budget,
|
||||
/* .remaining = */ budget,
|
||||
/* .state = */ initial_state,
|
||||
/* .force_pos = */ 0,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
#pragma once
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
enum common_reasoning_budget_state {
|
||||
REASONING_BUDGET_IDLE, // waiting for start sequence
|
||||
REASONING_BUDGET_COUNTING, // counting down tokens
|
||||
REASONING_BUDGET_FORCING, // forcing budget message + end sequence
|
||||
REASONING_BUDGET_WAITING_UTF8, // budget exhausted, waiting for UTF-8 completion
|
||||
REASONING_BUDGET_DONE, // passthrough forever
|
||||
};
|
||||
|
||||
// Creates a reasoning budget sampler that limits token generation inside a
|
||||
// reasoning block (e.g. between <think> and </think>).
|
||||
//
|
||||
// State machine: IDLE -> COUNTING -> WAITING_UTF8 -> FORCING -> DONE
|
||||
// IDLE: passthrough, watching for start_tokens sequence
|
||||
// COUNTING: counting down remaining tokens, watching for natural end_tokens
|
||||
// WAITING_UTF8: budget exhausted, allowing tokens to complete a UTF-8 sequence
|
||||
// FORCING: forces forced_tokens token-by-token (all other logits -> -inf)
|
||||
// DONE: passthrough forever
|
||||
//
|
||||
// Parameters:
|
||||
// vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr)
|
||||
// start_tokens - token sequence that activates counting
|
||||
// end_tokens - token sequence for natural deactivation
|
||||
// forced_tokens - token sequence forced when budget expires
|
||||
// budget - max tokens allowed in the reasoning block
|
||||
// initial_state - initial state of the sampler (e.g. IDLE or COUNTING)
|
||||
// note: COUNTING with budget <= 0 is promoted to FORCING
|
||||
//
|
||||
struct llama_sampler * common_reasoning_budget_init(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget,
|
||||
common_reasoning_budget_state initial_state);
|
||||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "reasoning-budget.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
|
|
@ -250,6 +251,17 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
|||
}
|
||||
}
|
||||
|
||||
// reasoning budget sampler — added first so it can force tokens before other samplers
|
||||
if (params.reasoning_budget_tokens >= 0 && !params.reasoning_budget_forced.empty()) {
|
||||
samplers.push_back(common_reasoning_budget_init(
|
||||
vocab,
|
||||
params.reasoning_budget_start,
|
||||
params.reasoning_budget_end,
|
||||
params.reasoning_budget_forced,
|
||||
params.reasoning_budget_tokens,
|
||||
params.reasoning_budget_activate_immediately ? REASONING_BUDGET_COUNTING : REASONING_BUDGET_IDLE));
|
||||
}
|
||||
|
||||
if (params.has_logit_bias()) {
|
||||
samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data()));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
#include "unicode.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// implementation adopted from src/unicode.cpp
|
||||
|
||||
|
|
@ -67,6 +69,20 @@ utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t off
|
|||
return utf8_parse_result(utf8_parse_result::INVALID);
|
||||
}
|
||||
|
||||
bool common_utf8_is_complete(const std::string & s) {
|
||||
if (s.empty()) {
|
||||
return true;
|
||||
}
|
||||
for (int i = 1; i <= std::min(4, (int)s.size()); i++) {
|
||||
unsigned char c = s[s.size() - i];
|
||||
if ((c & 0xC0) != 0x80) {
|
||||
int expected = (c >= 0xF0) ? 4 : (c >= 0xE0) ? 3 : (c >= 0xC0) ? 2 : 1;
|
||||
return i >= expected;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string common_unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) {
|
||||
std::string result;
|
||||
for (size_t i = 0; i < cps.size(); ++i) {
|
||||
|
|
|
|||
|
|
@ -20,6 +20,9 @@ struct utf8_parse_result {
|
|||
// Returns 0 for invalid first bytes
|
||||
size_t common_utf8_sequence_length(unsigned char first_byte);
|
||||
|
||||
// Check if a string ends with a complete UTF-8 sequence.
|
||||
bool common_utf8_is_complete(const std::string & s);
|
||||
|
||||
// Parse a single UTF-8 codepoint from input
|
||||
utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t offset);
|
||||
|
||||
|
|
|
|||
|
|
@ -4390,15 +4390,31 @@ class Qwen3Model(Qwen2Model):
|
|||
hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
|
||||
self.origin_hf_arch = hparams.get('architectures', [None])[0]
|
||||
|
||||
# a bit hacky, but currently the only way to detect if this is a rerank model
|
||||
# ref: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B
|
||||
if self._is_qwen3_reranker():
|
||||
self._find_rerank_config()
|
||||
|
||||
def _is_qwen3_reranker(self) -> bool:
|
||||
readme_path = self.dir_model / "README.md"
|
||||
readme_text = ""
|
||||
if readme_path.exists():
|
||||
with readme_path.open("r", encoding="utf-8") as f:
|
||||
readme_text = f.read()
|
||||
if "# Qwen3-Reranker" in readme_text:
|
||||
self._find_rerank_config()
|
||||
|
||||
name_hints = [
|
||||
str(self.dir_model.name),
|
||||
str(self.hparams.get("_name_or_path", "")),
|
||||
str(self.hparams.get("model_type", "")),
|
||||
str(self.origin_hf_arch or ""),
|
||||
]
|
||||
name_hints = [hint.lower() for hint in name_hints if hint]
|
||||
|
||||
if "# qwen3-reranker" in readme_text.lower() or "# qwen3-vl-reranker" in readme_text.lower():
|
||||
return True
|
||||
|
||||
if any("qwen3-reranker" in hint or "qwen3-vl-reranker" in hint for hint in name_hints):
|
||||
return True
|
||||
|
||||
return "sequenceclassification" in (self.origin_hf_arch or "").lower()
|
||||
|
||||
def set_vocab(self):
|
||||
# deal with intern-s1-mini
|
||||
|
|
|
|||
|
|
@ -382,17 +382,27 @@ use 1 SYCL GPUs: [0] with Max compute units:512
|
|||
|
||||
## Windows
|
||||
|
||||
### I. Setup Environment
|
||||
|
||||
1. Install GPU driver
|
||||
### Install GPU driver
|
||||
|
||||
Intel GPU drivers instructions guide and download page can be found here: [Get Intel GPU Drivers](https://www.intel.com/content/www/us/en/products/docs/discrete-gpus/arc/software/drivers.html).
|
||||
|
||||
2. Install Visual Studio
|
||||
### Option 1: download the binary package directly
|
||||
|
||||
Download the binary package for Windows from: https://github.com/ggml-org/llama.cpp/releases.
|
||||
|
||||
Extract the package to local folder, run the llama tools directly. Refer to [Run the inference](#iii-run-the-inference-1).
|
||||
|
||||
Note, the package includes the SYCL running time and all depended dll files, no need to install oneAPI package and activte them.
|
||||
|
||||
### Option 2: build locally from the source code.
|
||||
|
||||
#### I. Setup environment
|
||||
|
||||
1. Install Visual Studio
|
||||
|
||||
If you already have a recent version of Microsoft Visual Studio, you can skip this step. Otherwise, please refer to the official download page for [Microsoft Visual Studio](https://visualstudio.microsoft.com/).
|
||||
|
||||
3. Install Intel® oneAPI Base toolkit
|
||||
2. Install Intel® oneAPI Base toolkit
|
||||
|
||||
SYCL backend depends on:
|
||||
- Intel® oneAPI DPC++/C++ compiler/running-time.
|
||||
|
|
@ -443,25 +453,25 @@ Output (example):
|
|||
[ext_oneapi_level_zero:gpu:0] Intel(R) Level-Zero, Intel(R) Iris(R) Xe Graphics 1.3 [1.3.28044]
|
||||
```
|
||||
|
||||
4. Install build tools
|
||||
3. Install build tools
|
||||
|
||||
a. Download & install cmake for Windows: https://cmake.org/download/ (CMake can also be installed from Visual Studio Installer)
|
||||
b. The new Visual Studio will install Ninja as default. (If not, please install it manually: https://ninja-build.org/)
|
||||
|
||||
|
||||
### II. Build llama.cpp
|
||||
#### II. Build llama.cpp
|
||||
|
||||
You could download the release package for Windows directly, which including binary files and depended oneAPI dll files.
|
||||
|
||||
Choose one of following methods to build from source code.
|
||||
|
||||
#### 1. Script
|
||||
##### Option 1: Script
|
||||
|
||||
```sh
|
||||
.\examples\sycl\win-build-sycl.bat
|
||||
```
|
||||
|
||||
#### 2. CMake
|
||||
##### Option 2: CMake
|
||||
|
||||
On the oneAPI command line window, step into the llama.cpp main directory and run the following:
|
||||
|
||||
|
|
@ -490,7 +500,7 @@ cmake --preset x64-windows-sycl-debug
|
|||
cmake --build build-x64-windows-sycl-debug -j --target llama-completion
|
||||
```
|
||||
|
||||
#### 3. Visual Studio
|
||||
##### Option 3: Visual Studio
|
||||
|
||||
You have two options to use Visual Studio to build llama.cpp:
|
||||
- As CMake Project using CMake presets.
|
||||
|
|
@ -500,7 +510,7 @@ You have two options to use Visual Studio to build llama.cpp:
|
|||
|
||||
All following commands are executed in PowerShell.
|
||||
|
||||
##### - Open as a CMake Project
|
||||
###### - Open as a CMake Project
|
||||
|
||||
You can use Visual Studio to open the `llama.cpp` folder directly as a CMake project. Before compiling, select one of the SYCL CMake presets:
|
||||
|
||||
|
|
@ -515,7 +525,7 @@ You can use Visual Studio to open the `llama.cpp` folder directly as a CMake pro
|
|||
cmake --build build --config Release -j --target llama-completion
|
||||
```
|
||||
|
||||
##### - Generating a Visual Studio Solution
|
||||
###### - Generating a Visual Studio Solution
|
||||
|
||||
You can use Visual Studio solution to build and work on llama.cpp on Windows. You need to convert the CMake Project into a `.sln` file.
|
||||
|
||||
|
|
@ -603,7 +613,7 @@ found 2 SYCL devices:
|
|||
|
||||
```
|
||||
|
||||
#### Choose level-zero devices
|
||||
##### Choose level-zero devices
|
||||
|
||||
|Chosen Device ID|Setting|
|
||||
|-|-|
|
||||
|
|
@ -611,7 +621,7 @@ found 2 SYCL devices:
|
|||
|1|`set ONEAPI_DEVICE_SELECTOR="level_zero:1"`|
|
||||
|0 & 1|`set ONEAPI_DEVICE_SELECTOR="level_zero:0;level_zero:1"` or `set ONEAPI_DEVICE_SELECTOR="level_zero:*"`|
|
||||
|
||||
#### Execute
|
||||
##### Execute
|
||||
|
||||
Choose one of following methods to run.
|
||||
|
||||
|
|
@ -669,7 +679,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
|
|||
|
||||
## Environment Variable
|
||||
|
||||
#### Build
|
||||
### Build
|
||||
|
||||
| Name | Value | Function |
|
||||
|--------------------|---------------------------------------|---------------------------------------------|
|
||||
|
|
@ -684,7 +694,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
|
|||
|
||||
1. FP32 or FP16 have different performance impact to LLM. Recommended to test them for better prompt processing performance on your models. You need to rebuild the code after change `GGML_SYCL_F16=OFF/ON`.
|
||||
|
||||
#### Runtime
|
||||
### Runtime
|
||||
|
||||
| Name | Value | Function |
|
||||
|-------------------|------------------|---------------------------------------------------------------------------------------------------------------------------|
|
||||
|
|
@ -777,7 +787,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
|
|||
```
|
||||
|
||||
### **GitHub contribution**:
|
||||
Please add the `SYCL :` prefix/tag in issues/PRs titles to help the SYCL contributors to check/address them without delay.
|
||||
Please add the `[SYCL]` prefix/tag in issues/PRs titles to help the SYCL contributors to check/address them without delay.
|
||||
|
||||
## TODO
|
||||
|
||||
|
|
|
|||
18
docs/ops.md
18
docs/ops.md
|
|
@ -23,7 +23,7 @@ Legend:
|
|||
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
|
||||
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
|
|
@ -31,7 +31,7 @@ Legend:
|
|||
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
|
|
@ -47,7 +47,7 @@ Legend:
|
|||
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GATED_DELTA_NET | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
|
|
@ -64,7 +64,7 @@ Legend:
|
|||
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
|
||||
|
|
@ -76,7 +76,7 @@ Legend:
|
|||
| OUT_PROD | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | 🟡 |
|
||||
| PAD | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| POOL_1D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| POOL_1D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
|
|
@ -86,7 +86,7 @@ Legend:
|
|||
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| ROPE | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
|
|
@ -97,13 +97,13 @@ Legend:
|
|||
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
|
|
|
|||
8525
docs/ops/CPU.csv
8525
docs/ops/CPU.csv
File diff suppressed because it is too large
Load Diff
14129
docs/ops/SYCL.csv
14129
docs/ops/SYCL.csv
File diff suppressed because it is too large
Load Diff
|
|
@ -633,7 +633,7 @@ class SchemaConverter:
|
|||
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))
|
||||
|
||||
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
|
||||
items = schema.get('items') or schema['prefixItems']
|
||||
items = schema.get('items', schema.get('prefixItems'))
|
||||
if isinstance(items, list):
|
||||
return self._add_rule(
|
||||
rule_name,
|
||||
|
|
|
|||
|
|
@ -8,7 +8,12 @@ extern "C" {
|
|||
|
||||
#define RPC_PROTO_MAJOR_VERSION 3
|
||||
#define RPC_PROTO_MINOR_VERSION 6
|
||||
#define RPC_PROTO_PATCH_VERSION 0
|
||||
#define RPC_PROTO_PATCH_VERSION 1
|
||||
|
||||
#ifdef __cplusplus
|
||||
static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION");
|
||||
#endif
|
||||
|
||||
#define GGML_RPC_MAX_SERVERS 16
|
||||
|
||||
// backend API
|
||||
|
|
|
|||
|
|
@ -2,28 +2,29 @@
|
|||
#include "ggml-cuda/common.cuh"
|
||||
|
||||
template <int S_v, bool KDA>
|
||||
__global__ void gated_delta_net_cuda(const float * q,
|
||||
const float * k,
|
||||
const float * v,
|
||||
const float * g,
|
||||
const float * beta,
|
||||
const float * curr_state,
|
||||
float * dst,
|
||||
int64_t H,
|
||||
int64_t n_tokens,
|
||||
int64_t n_seqs,
|
||||
int64_t sq1,
|
||||
int64_t sq2,
|
||||
int64_t sq3,
|
||||
int64_t sv1,
|
||||
int64_t sv2,
|
||||
int64_t sv3,
|
||||
int64_t sb1,
|
||||
int64_t sb2,
|
||||
int64_t sb3,
|
||||
int64_t rq1,
|
||||
int64_t rq3,
|
||||
float scale) {
|
||||
__global__ void __launch_bounds__(S_v, 1)
|
||||
gated_delta_net_cuda(const float * q,
|
||||
const float * k,
|
||||
const float * v,
|
||||
const float * g,
|
||||
const float * beta,
|
||||
const float * curr_state,
|
||||
float * dst,
|
||||
const int64_t H,
|
||||
const int64_t n_tokens,
|
||||
const int64_t n_seqs,
|
||||
const int64_t sq1,
|
||||
const int64_t sq2,
|
||||
const int64_t sq3,
|
||||
const int64_t sv1,
|
||||
const int64_t sv2,
|
||||
const int64_t sv3,
|
||||
const int64_t sb1,
|
||||
const int64_t sb2,
|
||||
const int64_t sb3,
|
||||
const int64_t rq1,
|
||||
const int64_t rq3,
|
||||
const float scale) {
|
||||
const int64_t h_idx = blockIdx.x;
|
||||
const int64_t sequence = blockIdx.y;
|
||||
const int col = threadIdx.x; // each thread owns one column
|
||||
|
|
@ -40,8 +41,14 @@ __global__ void gated_delta_net_cuda(const float * q,
|
|||
curr_state += state_offset;
|
||||
attn_data += (sequence * n_tokens * H + h_idx) * S_v;
|
||||
|
||||
// Load state column into registers
|
||||
// GCN and CDNA devices spill registers, we use shared mem for them. See https://github.com/ggml-org/llama.cpp/pull/20282#issuecomment-4025770229
|
||||
// TODO: check optimal path for RDNA1 and RDNA2 devices.
|
||||
#if (defined(GGML_USE_HIP) && !defined(RDNA3) && !defined(RDNA4)) || defined(GGML_USE_MUSA)
|
||||
extern __shared__ float s_shared[];
|
||||
float * s = s_shared + col * S_v;
|
||||
#else
|
||||
float s[S_v];
|
||||
#endif
|
||||
#pragma unroll
|
||||
for (int i = 0; i < S_v; i++) {
|
||||
s[i] = curr_state[i * S_v + col];
|
||||
|
|
@ -114,6 +121,15 @@ __global__ void gated_delta_net_cuda(const float * q,
|
|||
}
|
||||
}
|
||||
|
||||
static size_t calculate_smem(const int sv, int cc)
|
||||
{
|
||||
size_t smem = 0;
|
||||
if ((GGML_CUDA_CC_IS_AMD(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_RDNA4(cc)) || GGML_CUDA_CC_IS_MTHREADS(cc)) {
|
||||
smem = sv * sv * sizeof(float);
|
||||
}
|
||||
return smem;
|
||||
}
|
||||
|
||||
template <bool KDA>
|
||||
static void launch_gated_delta_net(
|
||||
const float * q_d, const float * k_d, const float * v_d,
|
||||
|
|
@ -129,25 +145,36 @@ static void launch_gated_delta_net(
|
|||
dim3 grid_dims(H, n_seqs, 1);
|
||||
dim3 block_dims(S_v, 1, 1);
|
||||
|
||||
int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
|
||||
switch (S_v) {
|
||||
case 32:
|
||||
gated_delta_net_cuda<32, KDA><<<grid_dims, block_dims, 0, stream>>>(
|
||||
case 32: {
|
||||
constexpr int sv = 32;
|
||||
size_t smem = calculate_smem(sv, cc);
|
||||
gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
|
||||
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, rq1, rq3, scale);
|
||||
break;
|
||||
case 64:
|
||||
gated_delta_net_cuda<64, KDA><<<grid_dims, block_dims, 0, stream>>>(
|
||||
}
|
||||
case 64: {
|
||||
constexpr int sv = 64;
|
||||
size_t smem = calculate_smem(sv, cc);
|
||||
gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
|
||||
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, rq1, rq3, scale);
|
||||
break;
|
||||
case 128:
|
||||
gated_delta_net_cuda<128, KDA><<<grid_dims, block_dims, 0, stream>>>(
|
||||
}
|
||||
case 128: {
|
||||
constexpr int sv = 128;
|
||||
size_t smem = calculate_smem(sv, cc);
|
||||
gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
|
||||
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, rq1, rq3, scale);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
|
|||
int row = tid / load_cols;
|
||||
int col = tid % load_cols;
|
||||
#pragma unroll
|
||||
for (int idx = tid; idx < total_elems; idx += split_d_inner) {
|
||||
for (int idx = 0; idx < total_elems; idx += split_d_inner) {
|
||||
if (row < (int)split_d_inner) {
|
||||
smem[row * n_cols + col] = x_block[row * stride_x + col];
|
||||
}
|
||||
|
|
@ -84,6 +84,9 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
|
|||
col += split_d_inner;
|
||||
row += col / load_cols;
|
||||
col = col % load_cols;
|
||||
if (idx >= total_elems - tid - split_d_inner) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ struct ggml_metal {
|
|||
uint64_t fuse_cnt[GGML_OP_COUNT];
|
||||
|
||||
// capture state
|
||||
bool capture_next_compute;
|
||||
int capture_compute;
|
||||
bool capture_started;
|
||||
|
||||
id<MTLCaptureScope> capture_scope;
|
||||
|
|
@ -158,10 +158,17 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
|
|||
GGML_LOG_INFO("%s: use concurrency = %s\n", __func__, res->use_concurrency ? "true" : "false");
|
||||
GGML_LOG_INFO("%s: use graph optimize = %s\n", __func__, res->use_graph_optimize ? "true" : "false");
|
||||
|
||||
res->capture_next_compute = false;
|
||||
res->capture_compute = 0;
|
||||
res->capture_started = false;
|
||||
res->capture_scope = nil;
|
||||
|
||||
{
|
||||
const char * val = getenv("GGML_METAL_CAPTURE_COMPUTE");
|
||||
if (val) {
|
||||
res->capture_compute = atoi(val);
|
||||
}
|
||||
}
|
||||
|
||||
res->has_error = false;
|
||||
|
||||
res->gf = nil;
|
||||
|
|
@ -458,9 +465,13 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *
|
|||
|
||||
ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb;
|
||||
|
||||
const bool use_capture = ctx->capture_next_compute;
|
||||
if (ctx->capture_compute >= 0) {
|
||||
ctx->capture_compute--;
|
||||
}
|
||||
|
||||
const bool use_capture = ctx->capture_compute == 0;
|
||||
if (use_capture) {
|
||||
ctx->capture_next_compute = false;
|
||||
ctx->capture_compute = -1;
|
||||
|
||||
// make sure all previous computations have finished before starting the capture
|
||||
if (ctx->cmd_buf_last) {
|
||||
|
|
@ -469,6 +480,10 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *
|
|||
}
|
||||
|
||||
if (!ctx->capture_started) {
|
||||
NSString * path = [NSString stringWithFormat:@"/tmp/perf-metal-%d.gputrace", getpid()];
|
||||
|
||||
GGML_LOG_WARN("%s: capturing graph in %s\n", __func__, [path UTF8String]);
|
||||
|
||||
// create capture scope
|
||||
id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);
|
||||
ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:device];
|
||||
|
|
@ -476,7 +491,7 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *
|
|||
MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
|
||||
descriptor.captureObject = ctx->capture_scope;
|
||||
descriptor.destination = MTLCaptureDestinationGPUTraceDocument;
|
||||
descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]];
|
||||
descriptor.outputURL = [NSURL fileURLWithPath:path];
|
||||
|
||||
NSError * error = nil;
|
||||
if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
|
||||
|
|
@ -683,7 +698,7 @@ void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) {
|
|||
idx_end,
|
||||
ctx->use_fusion,
|
||||
ctx->use_concurrency,
|
||||
ctx->capture_next_compute,
|
||||
ctx->capture_compute,
|
||||
ctx->debug_graph,
|
||||
ctx->debug_fusion);
|
||||
|
||||
|
|
@ -718,5 +733,5 @@ bool ggml_metal_supports_family(ggml_metal_t ctx, int family) {
|
|||
}
|
||||
|
||||
void ggml_metal_capture_next_compute(ggml_metal_t ctx) {
|
||||
ctx->capture_next_compute = true;
|
||||
ctx->capture_compute = 1;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@
|
|||
#define N_R0_Q4_K 2
|
||||
#define N_SG_Q4_K 2
|
||||
|
||||
#define N_R0_Q5_K 2
|
||||
#define N_R0_Q5_K 1
|
||||
#define N_SG_Q5_K 2
|
||||
|
||||
#define N_R0_Q6_K 2
|
||||
|
|
|
|||
|
|
@ -874,4 +874,95 @@ static bool fast_fp16_available(const int cc) {
|
|||
return true; //Intel GPUs always support FP16.
|
||||
}
|
||||
|
||||
enum class block_reduce_method {
|
||||
MAX,
|
||||
SUM,
|
||||
};
|
||||
|
||||
template<block_reduce_method method_t, typename T, int warp_size>
|
||||
struct block_reduce_policy;
|
||||
|
||||
template <typename T, typename... Ts>
|
||||
inline constexpr bool is_any = (std::is_same_v<T, Ts> || ...);
|
||||
|
||||
template<typename...>
|
||||
inline constexpr bool ggml_sycl_dependent_false_v = false;
|
||||
|
||||
#define WARP_32_SIZE 32
|
||||
|
||||
template <typename T, int warp_size> struct block_reduce_policy<block_reduce_method::SUM, T, warp_size> {
|
||||
static T reduce(T val) {
|
||||
if constexpr (is_any<T, float, sycl::float2, sycl::half2, int>) {
|
||||
return warp_reduce_sum<warp_size>(val);
|
||||
} else {
|
||||
static_assert(ggml_sycl_dependent_false_v<T>, "Unsupported type for block reduce sum");
|
||||
}
|
||||
}
|
||||
|
||||
static T sentinel() {
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
return 0.0f;
|
||||
} else if constexpr (std::is_same_v<T, sycl::float2>) {
|
||||
return sycl::float2(0.0f, 0.0f);
|
||||
} else if constexpr (std::is_same_v<T, sycl::half2>) {
|
||||
return sycl::half2(0.0f, 0.0f);
|
||||
} else if constexpr (std::is_same_v<T, int>) {
|
||||
return 0;
|
||||
} else {
|
||||
static_assert(ggml_sycl_dependent_false_v<T>, "Unsupported type for block reduce sum");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int warp_size> struct block_reduce_policy<block_reduce_method::MAX, T, warp_size> {
|
||||
static T reduce(T val) {
|
||||
if constexpr (is_any<T, float, sycl::half2>) {
|
||||
return warp_reduce_max<warp_size>(val);
|
||||
} else {
|
||||
static_assert(ggml_sycl_dependent_false_v<T>, "Unsupported type for block reduce max");
|
||||
}
|
||||
}
|
||||
|
||||
static T sentinel() {
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
return -INFINITY;
|
||||
} else if constexpr (std::is_same_v<T, sycl::half2>) {
|
||||
return sycl::half2(-INFINITY, -INFINITY);
|
||||
} else {
|
||||
static_assert(ggml_sycl_dependent_false_v<T>, "Unsupported type for block reduce max");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <block_reduce_method reduce_method_t, int warp_size, typename T>
|
||||
static T block_reduce(T val, T * shared_vals, int block_size_template) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
val = block_reduce_policy<reduce_method_t, T,warp_size>::reduce(val);
|
||||
const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;
|
||||
const int nthreads = item_ct1.get_local_range(2);
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
|
||||
if (block_size > warp_size) {
|
||||
assert((block_size <= 1024) && (block_size % warp_size) == 0);
|
||||
const int warp_id = item_ct1.get_local_id(2) / warp_size;
|
||||
const int lane_id = item_ct1.get_local_id(2) % warp_size;
|
||||
if (lane_id == 0) {
|
||||
shared_vals[warp_id] = val;
|
||||
}
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
size_t nreduce = nwarps / WARP_SIZE;
|
||||
float tmp = 0.f;
|
||||
if (lane_id < (static_cast<int>(block_size) / warp_size)) {
|
||||
for (size_t i = 0; i < nreduce; i += 1)
|
||||
{
|
||||
tmp += shared_vals[lane_id + i * WARP_SIZE];
|
||||
}
|
||||
}
|
||||
return block_reduce_policy<reduce_method_t, T, warp_size>::reduce(tmp);
|
||||
}
|
||||
return val;
|
||||
}
|
||||
|
||||
#endif // GGML_SYCL_COMMON_HPP
|
||||
|
|
|
|||
|
|
@ -39,6 +39,11 @@ template<typename dst_t, typename src_t>
|
|||
return sycl::ext::oneapi::bfloat16(float(x));
|
||||
} else if constexpr (std::is_same_v<src_t, sycl::ext::oneapi::bfloat16>) {
|
||||
return static_cast<float>(x);
|
||||
} else if constexpr (std::is_same_v<src_t, sycl::float2> && std::is_same_v<dst_t, sycl::half2>) {
|
||||
return x.template convert<sycl::half, sycl::rounding_mode::rte>();
|
||||
} else if constexpr (std::is_same_v<src_t, sycl::float2> &&
|
||||
std::is_same_v<dst_t, sycl::vec<sycl::ext::oneapi::bfloat16, 2>>) {
|
||||
return {x.x, x.y};
|
||||
} else if constexpr(std::is_same_v<dst_t, int32_t>) {
|
||||
return int32_t(x);
|
||||
} else {
|
||||
|
|
@ -46,4 +51,5 @@ template<typename dst_t, typename src_t>
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
#endif // GGML_SYCL_CONVERT_HPP
|
||||
|
|
|
|||
|
|
@ -9,23 +9,32 @@
|
|||
#define SYCL_LOCAL_ID_CALC(ITEM, IDX) \
|
||||
(ITEM.get_local_range(IDX) * ITEM.get_group(IDX) + ITEM.get_local_id(IDX))
|
||||
|
||||
static void acc_f32(const float * x, const float * y, float * dst, const int64_t ne,
|
||||
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
|
||||
const int64_t s11, const int64_t s12, const int64_t s13, const int64_t offset) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
const int64_t i = SYCL_LOCAL_ID_CALC(item_ct1, 2);
|
||||
|
||||
static void acc_f32(const float * x, const float * y, float * dst, const int ne,
|
||||
const int ne10, const int ne11, const int ne12,
|
||||
const int nb1, const int nb2, int offset, const sycl::nd_item<1> &item_ct1) {
|
||||
const int i = SYCL_LOCAL_ID_CALC(item_ct1, 0);
|
||||
if (i >= ne) {
|
||||
return;
|
||||
}
|
||||
int src1_idx = i - offset;
|
||||
int oz = src1_idx / nb2;
|
||||
int oy = (src1_idx - (oz * nb2)) / nb1;
|
||||
int ox = src1_idx % nb1;
|
||||
if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
|
||||
dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
|
||||
} else {
|
||||
dst[i] = x[i];
|
||||
|
||||
int64_t src1_idx = i - offset;
|
||||
|
||||
int64_t tmp = src1_idx;
|
||||
const int64_t i13 = tmp / s13;
|
||||
tmp -= i13 * s13;
|
||||
const int64_t i12 = tmp / s12;
|
||||
tmp -= i12 * s12;
|
||||
const int64_t i11 = tmp / s11;
|
||||
tmp -= i11 * s11;
|
||||
const int64_t i10 = tmp;
|
||||
|
||||
float val = x[i];
|
||||
if (src1_idx >= 0 && i10 < ne10 && i11 < ne11 && i12 < ne12 && i13 < ne13) {
|
||||
val += y[((i13*ne12 + i12) * ne11 + i11) * ne10 + i10];
|
||||
}
|
||||
dst[i] = val;
|
||||
}
|
||||
|
||||
/* Unary OP funcs */
|
||||
|
|
@ -364,18 +373,15 @@ static void gated_op_fused_geglu_quick(const T * x, const T * g, T * dst, const
|
|||
|
||||
namespace ggml_sycl_detail {
|
||||
static void acc_f32_sycl(const float *x, const float *y, float *dst,
|
||||
const int n_elements, const int ne10, const int ne11,
|
||||
const int ne12, const int nb1, const int nb2,
|
||||
const int offset, queue_ptr stream) {
|
||||
int num_blocks = ceil_div(n_elements, SYCL_ACC_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) *
|
||||
sycl::range<1>(SYCL_ACC_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_ACC_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset,
|
||||
item_ct1);
|
||||
});
|
||||
const int64_t n_elements, const int64_t ne10, const int64_t ne11,
|
||||
const int64_t ne12, const int64_t ne13, const int64_t s1, const int64_t s2, const int64_t s3,
|
||||
const int64_t offset, queue_ptr stream) {
|
||||
const int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset);
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
|
|
@ -402,25 +408,19 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
|
|||
|
||||
template<typename KernelInvoker, typename... Args>
|
||||
static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
#else
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
#endif
|
||||
GGML_ASSERT(dst->src[0]->type == dst->type);
|
||||
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
switch (dst->type) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
auto data_pts = cast_data<sycl::half>(dst);
|
||||
kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
auto data_pts = cast_data<float>(dst);
|
||||
|
|
@ -434,14 +434,10 @@ static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx,
|
|||
|
||||
template<typename KernelInvoker, typename... Args>
|
||||
static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
#else
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
#endif
|
||||
GGML_ASSERT(dst->src[0]->type == dst->type);
|
||||
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
|
@ -463,7 +459,6 @@ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & c
|
|||
GGML_ASSERT(src0->type == src1->type);
|
||||
}
|
||||
switch (dst->type) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
sycl::half * src0_p = (sycl::half *) src0_d;
|
||||
|
|
@ -484,7 +479,6 @@ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & c
|
|||
std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
float * src0_p = (float *) src0_d;
|
||||
|
|
@ -513,13 +507,9 @@ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & c
|
|||
|
||||
template<typename KernelInvoker, typename... Args>
|
||||
static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
#else
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
#endif
|
||||
|
||||
GGML_ASSERT(dst->src[0]->type == dst->type);
|
||||
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
|
|
@ -530,7 +520,6 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
|
|||
const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2];
|
||||
const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3];
|
||||
switch (dst->type) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
auto data_pts = cast_data<sycl::half>(dst);
|
||||
|
|
@ -539,7 +528,6 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
|
|||
main_stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
auto data_pts = cast_data<float>(dst);
|
||||
|
|
@ -868,22 +856,31 @@ static inline void ggml_sycl_op_trunc(ggml_backend_sycl_context & ctx, ggml_tens
|
|||
}
|
||||
|
||||
static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
const float * src1_d = (const float *) src1->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
const float * src1_dd = static_cast<const float*>(dst->src[1]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
|
||||
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
|
||||
// int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
|
||||
int offset = dst->op_params[3] / 4; // offset in bytes
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
GGML_ASSERT(dst->nb[0] == ggml_element_size(dst));
|
||||
GGML_ASSERT(ggml_is_contiguously_allocated(dst));
|
||||
|
||||
ggml_sycl_detail::acc_f32_sycl(src0_dd, src1_dd, dst_dd, (int)ggml_nelements(dst), (int)dst->src[1]->ne[0], (int)dst->src[1]->ne[1], (int)dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
|
||||
const int64_t s1 = dst->op_params[0] / sizeof(float);
|
||||
const int64_t s2 = dst->op_params[1] / sizeof(float);
|
||||
const int64_t s3 = dst->op_params[2] / sizeof(float);
|
||||
const int64_t offset = dst->op_params[3] / sizeof(float);
|
||||
|
||||
ggml_sycl_detail::acc_f32_sycl(src0_d, src1_d, dst_d, ggml_nelements(dst),
|
||||
src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
|
||||
s1, s2, s3, offset, stream);
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
|
|
|
|||
|
|
@ -4145,6 +4145,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|||
case GGML_OP_ROPE:
|
||||
ggml_sycl_rope(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ROPE_BACK:
|
||||
ggml_sycl_rope_back(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_IM2COL:
|
||||
ggml_sycl_im2col(ctx, dst);
|
||||
break;
|
||||
|
|
@ -4851,6 +4854,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
return max_bias == 0.0f;
|
||||
}
|
||||
case GGML_OP_ROPE:
|
||||
case GGML_OP_ROPE_BACK:
|
||||
case GGML_OP_IM2COL:
|
||||
return true;
|
||||
case GGML_OP_UPSCALE:
|
||||
|
|
@ -4872,8 +4876,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
k > 0 && k <= 32;
|
||||
}
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_ACC:
|
||||
return true;
|
||||
case GGML_OP_ACC:
|
||||
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
|
||||
case GGML_OP_PAD:
|
||||
// TODO: add circular padding support for syscl, see https://github.com/ggml-org/llama.cpp/pull/16985
|
||||
if (ggml_get_op_params_i32(op, 8) != 0) {
|
||||
|
|
|
|||
|
|
@ -202,47 +202,34 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
|
|||
}
|
||||
}
|
||||
|
||||
static void l2_norm_f32(const float* x, float* dst, const int ncols, const float eps,
|
||||
const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
|
||||
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
|
||||
item_ct1.get_local_id(1);
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
const int nthreads = item_ct1.get_local_range(2);
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
template<int warp_size>
|
||||
static void l2_norm_f32(const float * x, float * dst, const int ncols,
|
||||
const int64_t stride_row, const int64_t stride_channel,
|
||||
const int64_t stride_sample, const float eps,
|
||||
const sycl::nd_item<3>& item_ct1, float* s_sum, const int block_size) {
|
||||
const int nrows = item_ct1.get_group_range(2);
|
||||
const int nchannels = item_ct1.get_group_range(1);
|
||||
|
||||
const int row = item_ct1.get_group(2);
|
||||
const int channel = item_ct1.get_group(1);
|
||||
const int sample = item_ct1.get_group(0);
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
|
||||
x += sample*stride_sample + channel*stride_channel + row*stride_row;
|
||||
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
|
||||
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xi = x[row * ncols + col];
|
||||
const float xi = x[col];
|
||||
tmp += xi * xi;
|
||||
}
|
||||
|
||||
// sum up partial sums
|
||||
tmp = warp_reduce_sum(tmp, item_ct1);
|
||||
if (block_size > WARP_SIZE) {
|
||||
|
||||
int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
|
||||
int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
|
||||
if (lane_id == 0) {
|
||||
s_sum[warp_id] = tmp;
|
||||
}
|
||||
/*
|
||||
DPCT1118:3: SYCL group functions and algorithms must be encountered in
|
||||
converged control flow. You may need to adjust the code.
|
||||
*/
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
size_t nreduce = nwarps / WARP_SIZE;
|
||||
tmp = 0.f;
|
||||
for (size_t i = 0; i < nreduce; i += 1)
|
||||
{
|
||||
tmp += s_sum[lane_id + i * WARP_SIZE];
|
||||
}
|
||||
tmp = warp_reduce_sum(tmp, item_ct1);
|
||||
}
|
||||
|
||||
const float scale = sycl::rsqrt(sycl::max(tmp, eps * eps));
|
||||
tmp = block_reduce<block_reduce_method::SUM, warp_size>(tmp, s_sum, block_size);
|
||||
const float scale = sycl::rsqrt(sycl::fmax(tmp, eps * eps));
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
dst[row * ncols + col] = scale * x[row * ncols + col];
|
||||
dst[col] = scale * x[col];
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -369,42 +356,50 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
|
|||
}
|
||||
}
|
||||
|
||||
static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||
const int nrows, const float eps,
|
||||
queue_ptr stream, int device) {
|
||||
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
|
||||
template<int warp_size>
|
||||
static void l2_norm_f32_sycl(const float * x,
|
||||
float * dst,
|
||||
const int ncols,
|
||||
const int nrows,
|
||||
const int nchannels,
|
||||
const int nsamples,
|
||||
const int64_t stride_row,
|
||||
const int64_t stride_channel,
|
||||
const int64_t stride_sample,
|
||||
const float eps,
|
||||
queue_ptr stream,
|
||||
int device) {
|
||||
const dpct::dim3 blocks_num(nrows, nchannels, nsamples);
|
||||
|
||||
if (ncols < 1024) {
|
||||
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
||||
const dpct::dim3 block_dims(warp_size, 1, 1);
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
||||
sycl::nd_range<3>(blocks_num * block_dims,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
l2_norm_f32(x, dst, ncols, eps, item_ct1,
|
||||
nullptr, WARP_SIZE);
|
||||
[[sycl::reqd_sub_group_size(warp_size)]] {
|
||||
l2_norm_f32<warp_size>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
|
||||
nullptr, warp_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
else {
|
||||
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
|
||||
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
|
||||
assert(work_group_size % (warp_size * warp_size) == 0);
|
||||
const sycl::range<3> block_dims(1, 1, work_group_size);
|
||||
/*
|
||||
DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
|
||||
the limit. To get the device limit, query
|
||||
info::device::max_work_group_size. Adjust the work-group size if needed.
|
||||
*/
|
||||
int lsm_size = block_dims[2] > warp_size ? work_group_size / warp_size * sizeof(float): 0;
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
|
||||
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(lsm_size),
|
||||
cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
||||
sycl::nd_range<3>(blocks_num * block_dims,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
l2_norm_f32(x, dst, ncols, eps, item_ct1,
|
||||
get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
[[sycl::reqd_sub_group_size(warp_size)]] {
|
||||
l2_norm_f32<warp_size>(x, dst, ncols, stride_row, stride_channel, stride_sample,
|
||||
eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
|
@ -634,21 +629,28 @@ void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * d
|
|||
}
|
||||
|
||||
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
|
||||
const int64_t ne00 = dst->src[0]->ne[0];
|
||||
const int64_t nrows = ggml_nrows(dst->src[0]);
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
GGML_TENSOR_UNARY_OP_LOCALS;
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
GGML_ASSERT(eps >= 0.0f);
|
||||
|
||||
l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
|
||||
const size_t ts0 = ggml_type_size(src0->type);
|
||||
GGML_ASSERT(nb00 == ts0);
|
||||
const int64_t s01 = nb01 / ts0;
|
||||
const int64_t s02 = nb02 / ts0;
|
||||
const int64_t s03 = nb03 / ts0;
|
||||
|
||||
/*support both WARP_SIZE or WARP_32_SIZE in code
|
||||
choose by hardware for better performance
|
||||
*/
|
||||
l2_norm_f32_sycl<WARP_SIZE>(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream, ctx.device);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
#include "rope.hpp"
|
||||
#include "convert.hpp"
|
||||
#include "ggml-sycl/common.hpp"
|
||||
#include "ggml.h"
|
||||
|
||||
|
|
@ -15,366 +16,489 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
|||
return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y));
|
||||
}
|
||||
|
||||
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
|
||||
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
||||
static void rope_yarn(
|
||||
float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
|
||||
float * cos_theta, float * sin_theta) {
|
||||
// Get n-d rotational scaling corrected for extrapolation
|
||||
template <bool forward>
|
||||
static void rope_yarn(const float theta_extrap, const float freq_scale,
|
||||
const rope_corr_dims corr_dims, const int64_t i0,
|
||||
const float ext_factor, float mscale, float &cos_theta,
|
||||
float &sin_theta) {
|
||||
float theta_interp = freq_scale * theta_extrap;
|
||||
float theta = theta_interp;
|
||||
if (ext_factor != 0.0f) {
|
||||
float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
|
||||
float ramp_mix =
|
||||
rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
|
||||
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
||||
|
||||
// Get n-d magnitude scaling corrected for interpolation
|
||||
mscale *= 1.0f + 0.1f * sycl::log(1.0f / freq_scale);
|
||||
}
|
||||
*cos_theta = sycl::cos(theta) * mscale;
|
||||
*sin_theta = sycl::sin(theta) * mscale;
|
||||
cos_theta = sycl::cos(theta) * mscale;
|
||||
sin_theta = sycl::sin(theta) * mscale;
|
||||
if (!forward) {
|
||||
sin_theta *= -1.0f;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, bool has_ff>
|
||||
static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
|
||||
const int32_t * pos, float freq_scale, float ext_factor, float attn_factor,
|
||||
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
|
||||
const sycl::nd_item<3> & item_ct1) {
|
||||
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
|
||||
template <bool forward, bool has_ff, typename T, typename D>
|
||||
static void rope_norm(const T *x, D *dst, const int ne00, const int ne01,
|
||||
const int ne02, const int s01, const int s02,
|
||||
const int s03, const int s1, const int s2, const int s3,
|
||||
const int n_dims, const int32_t *pos,
|
||||
const float freq_scale, const float ext_factor,
|
||||
const float attn_factor, const rope_corr_dims corr_dims,
|
||||
const float theta_scale, const float *freq_factors,
|
||||
const int64_t *row_indices, const int set_rows_stride) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
||||
item_ct1.get_local_id(1));
|
||||
|
||||
if (i0 >= ne0) {
|
||||
if (i0 >= ne00) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
|
||||
const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
|
||||
const int row0 = row % ne1;
|
||||
const int channel0 = row / ne1;
|
||||
const uint32_t i3 = row_dst / (ne01 * ne02);
|
||||
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
|
||||
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
|
||||
|
||||
const int i = row * ne0 + i0;
|
||||
const int i2 = channel0 * s2 + row0 * s1 + i0;
|
||||
int idst = i0 + i1 * s1 + i2 * s2 + i3 * s3;
|
||||
const int ix = i0 + i1 * s01 + i2 * s02 + i3 * s03;
|
||||
|
||||
if (set_rows_stride != 0) {
|
||||
idst = i1 * s1 + i0;
|
||||
idst += row_indices[i2] * set_rows_stride;
|
||||
}
|
||||
|
||||
const auto &store_coaelsced = [&](float x0, float x1) {
|
||||
if constexpr (std::is_same_v<float, D>) {
|
||||
sycl::float2 v = sycl::float2(x0, x1);
|
||||
ggml_sycl_memcpy_1<8>(dst + idst, &v);
|
||||
} else if constexpr (std::is_same_v<sycl::half, D>) {
|
||||
sycl::half2 v = sycl::half2(x0, x1);
|
||||
ggml_sycl_memcpy_1<4>(dst + idst, &v);
|
||||
}
|
||||
};
|
||||
if (i0 >= n_dims) {
|
||||
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2);
|
||||
store_coaelsced(x[ix + 0], x[ix + 1]);
|
||||
return;
|
||||
}
|
||||
|
||||
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
|
||||
const float theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
|
||||
|
||||
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
|
||||
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
|
||||
ext_factor, attn_factor, cos_theta, sin_theta);
|
||||
|
||||
const float x0 = x[i2 + 0];
|
||||
const float x1 = x[i2 + 1];
|
||||
const float x0 = x[ix + 0];
|
||||
const float x1 = x[ix + 1];
|
||||
|
||||
dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst[i + 1] = x0 * sin_theta + x1 * cos_theta;
|
||||
store_coaelsced(x0 * cos_theta - x1 * sin_theta,
|
||||
x0 * sin_theta + x1 * cos_theta);
|
||||
}
|
||||
|
||||
template <typename T, bool has_ff>
|
||||
static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
|
||||
const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
|
||||
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
|
||||
const sycl::nd_item<3> & item_ct1) {
|
||||
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
|
||||
template <bool forward, bool has_ff, typename T, typename D>
|
||||
static void rope_neox(const T *x, D *dst, const int ne00, const int ne01,
|
||||
const int ne02, const int s01, const int s02,
|
||||
const int s03, const int s1, const int s2, const int s3,
|
||||
const int n_dims, const int32_t *pos,
|
||||
const float freq_scale, const float ext_factor,
|
||||
const float attn_factor, const rope_corr_dims corr_dims,
|
||||
const float theta_scale, const float *freq_factors,
|
||||
const int64_t *row_indices, const int set_rows_stride) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
||||
item_ct1.get_local_id(1));
|
||||
|
||||
if (i0 >= ne0) {
|
||||
if (i0 >= ne00) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
|
||||
const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
|
||||
const int row0 = row % ne1;
|
||||
const int channel0 = row / ne1;
|
||||
const uint32_t i3 = row_dst / (ne01 * ne02);
|
||||
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
|
||||
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
|
||||
|
||||
const int i = row * ne0 + i0 / 2;
|
||||
const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
|
||||
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
|
||||
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
|
||||
|
||||
if (set_rows_stride != 0) {
|
||||
idst = i1 * s1 + i0 / 2;
|
||||
idst += row_indices[i2] * set_rows_stride;
|
||||
}
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2 + i0 / 2);
|
||||
dst[idst + i0 / 2 + 0] = ggml_sycl_cast<D>(x[ix + i0 / 2 + 0]);
|
||||
dst[idst + i0 / 2 + 1] = ggml_sycl_cast<D>(x[ix + i0 / 2 + 1]);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
|
||||
const float theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
|
||||
|
||||
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
|
||||
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
|
||||
ext_factor, attn_factor, cos_theta, sin_theta);
|
||||
|
||||
const float x0 = x[i2 + 0];
|
||||
const float x1 = x[i2 + n_dims / 2];
|
||||
const float x0 = x[ix + 0];
|
||||
const float x1 = x[ix + n_dims / 2];
|
||||
|
||||
dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
|
||||
dst[idst + 0] = ggml_sycl_cast<D>(x0 * cos_theta - x1 * sin_theta);
|
||||
dst[idst + n_dims / 2] = ggml_sycl_cast<D>(x0 * sin_theta + x1 * cos_theta);
|
||||
}
|
||||
|
||||
template <typename T, bool has_ff>
|
||||
static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
|
||||
const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
|
||||
const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
|
||||
const float theta_scale, const float * freq_factors, const mrope_sections sections,
|
||||
const bool is_imrope, const sycl::nd_item<3> & item_ct1) {
|
||||
// get index pos
|
||||
const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
|
||||
if (i0 >= ne0) {
|
||||
template <bool forward, bool has_ff, typename T>
|
||||
static void rope_multi(const T *x, T *dst, const int ne00, const int ne01,
|
||||
const int ne02, const int s01, const int s02,
|
||||
const int s03, const int s1, const int s2, const int s3,
|
||||
const int n_dims, const int32_t *pos,
|
||||
const float freq_scale, const float ext_factor,
|
||||
const float attn_factor, const rope_corr_dims corr_dims,
|
||||
const float theta_scale, const float *freq_factors,
|
||||
const mrope_sections sections, const bool is_imrope) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
||||
item_ct1.get_local_id(1));
|
||||
|
||||
if (i0 >= ne00) {
|
||||
return;
|
||||
}
|
||||
const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
|
||||
|
||||
const int row_x = row_dst % ne1;
|
||||
const int channel_x = row_dst / ne1;
|
||||
const int idst = (row_dst * ne0) + (i0 / 2);
|
||||
const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
|
||||
const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
|
||||
const uint32_t i3 = row_dst / (ne01 * ne02);
|
||||
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
|
||||
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
|
||||
|
||||
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
|
||||
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
*reinterpret_cast<sycl::vec<T, 2> *>(dst + idst + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + ix);
|
||||
dst[idst + i0 / 2 + 0] = x[ix + i0 / 2 + 0];
|
||||
dst[idst + i0 / 2 + 1] = x[ix + i0 / 2 + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
|
||||
const int sect_dims =
|
||||
sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
|
||||
const int sec_w = sections.v[1] + sections.v[0];
|
||||
const int sector = (i0 / 2) % sect_dims;
|
||||
|
||||
|
||||
float theta_base = 0.0;
|
||||
if (is_imrope) {
|
||||
if (sector % 3 == 1 && sector < 3 * sections.v[1]) {
|
||||
theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
|
||||
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) {
|
||||
theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
|
||||
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) {
|
||||
theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
|
||||
if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
|
||||
theta_base = pos[i2 + ne02 * 1] * dpct::pow(theta_scale, i0 / 2.0f);
|
||||
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
|
||||
theta_base = pos[i2 + ne02 * 2] * dpct::pow(theta_scale, i0 / 2.0f);
|
||||
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
|
||||
theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
|
||||
} else {
|
||||
theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
|
||||
theta_base = pos[i2 + ne02 * 3] * dpct::pow(theta_scale, i0 / 2.0f);
|
||||
}
|
||||
} else {
|
||||
if (sector < sections.v[0]) {
|
||||
theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sections.v[0] && sector < sec_w) {
|
||||
theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
|
||||
theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w + sections.v[2]) {
|
||||
theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
|
||||
theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
|
||||
} else if (sector >= sections.v[0] && sector < sec_w) {
|
||||
theta_base = pos[i2 + ne02 * 1] * dpct::pow(theta_scale, i0 / 2.0f);
|
||||
} else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
|
||||
theta_base = pos[i2 + ne02 * 2] * dpct::pow(theta_scale, i0 / 2.0f);
|
||||
} else if (sector >= sec_w + sections.v[2]) {
|
||||
theta_base = pos[i2 + ne02 * 3] * dpct::pow(theta_scale, i0 / 2.0f);
|
||||
}
|
||||
}
|
||||
|
||||
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
const float x0 = x[ix + 0];
|
||||
const float x1 = x[ix + n_dims/2];
|
||||
|
||||
// store results in dst
|
||||
dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst[idst + n_dims/2] = x0 * sin_theta + x1 * cos_theta;
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
|
||||
ext_factor, attn_factor, cos_theta, sin_theta);
|
||||
|
||||
const float x0 = x[ix + 0];
|
||||
const float x1 = x[ix + n_dims / 2];
|
||||
|
||||
dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst[idst + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
|
||||
}
|
||||
|
||||
template <bool forward, bool has_ff, typename T>
|
||||
static void rope_vision(const T *x, T *dst, const int ne00, const int ne01,
|
||||
const int ne02, const int s01, const int s02,
|
||||
const int s03, const int s1, const int s2, const int s3,
|
||||
const int n_dims, const int32_t *pos,
|
||||
const float freq_scale, const float ext_factor,
|
||||
const float attn_factor, const rope_corr_dims corr_dims,
|
||||
const float theta_scale, const float *freq_factors,
|
||||
const mrope_sections sections) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
||||
item_ct1.get_local_id(1));
|
||||
|
||||
|
||||
template <typename T, bool has_ff>
|
||||
static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
|
||||
const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
|
||||
const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
|
||||
const float theta_scale, const float * freq_factors, const mrope_sections sections,
|
||||
const sycl::nd_item<3> & item_ct1) {
|
||||
// get index pos
|
||||
const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
|
||||
if (i0 >= ne0) {
|
||||
if (i0 >= ne00) {
|
||||
return;
|
||||
}
|
||||
const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
|
||||
const int row_x = row_dst % ne1;
|
||||
const int channel_x = row_dst / ne1;
|
||||
const int idst = (row_dst * ne0) + (i0 / 2);
|
||||
const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
|
||||
|
||||
const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
|
||||
const uint32_t i3 = row_dst / (ne01 * ne02);
|
||||
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
|
||||
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
|
||||
|
||||
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
|
||||
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
|
||||
|
||||
const int sect_dims = sections.v[0] + sections.v[1];
|
||||
const int sector = (i0 / 2) % sect_dims;
|
||||
const int sec_w = sections.v[1] + sections.v[0];
|
||||
const int sector = (i0 / 2) % sect_dims;
|
||||
|
||||
float theta_base = 0.0f;
|
||||
float theta_base = 0.0;
|
||||
if (sector < sections.v[0]) {
|
||||
const int p = sector;
|
||||
theta_base = pos[channel_x] * sycl::pow(theta_scale, (float) p);
|
||||
} else {
|
||||
theta_base = pos[i2] * dpct::pow(theta_scale, p);
|
||||
} else if (sector >= sections.v[0] && sector < sec_w) {
|
||||
const int p = sector - sections.v[0];
|
||||
theta_base = pos[channel_x + ne2] * sycl::pow(theta_scale, (float) p);
|
||||
theta_base = pos[i2 + ne02] * dpct::pow(theta_scale, p);
|
||||
}
|
||||
|
||||
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
|
||||
ext_factor, attn_factor, cos_theta, sin_theta);
|
||||
|
||||
const float x0 = x[ix + 0];
|
||||
const float x1 = x[ix + n_dims];
|
||||
|
||||
// store results in dst
|
||||
dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst[idst + n_dims] = x0 * sin_theta + x1 * cos_theta;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
|
||||
const int n_dims, int nr, const int32_t * pos, const float freq_scale, const float freq_base,
|
||||
const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
|
||||
const float * freq_factors, queue_ptr stream) {
|
||||
GGML_ASSERT(ne0 % 2 == 0);
|
||||
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
||||
const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
|
||||
const sycl::range<3> block_nums(1, num_blocks_x, nr);
|
||||
template <bool forward, typename T, typename D>
|
||||
static void
|
||||
rope_norm_sycl(const T *x, D *dst, const int ne00, const int ne01,
|
||||
const int ne02, const int s01, const int s02, const int s03,
|
||||
const int s1, const int s2, const int s3, const int n_dims,
|
||||
const int nr, const int32_t *pos, const float freq_scale,
|
||||
const float freq_base, const float ext_factor,
|
||||
const float attn_factor, const rope_corr_dims corr_dims,
|
||||
const float *freq_factors, const int64_t *row_indices,
|
||||
const int set_rows_stride, dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(ne00 % 2 == 0);
|
||||
const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
||||
const int n_blocks_x =
|
||||
(ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
|
||||
const dpct::dim3 block_nums(nr, n_blocks_x, 1);
|
||||
|
||||
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
||||
|
||||
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
||||
|
||||
if (freq_factors == nullptr) {
|
||||
/*
|
||||
DPCT1049:40: The work-group size passed to the SYCL kernel may exceed
|
||||
the limit. To get the device limit, query
|
||||
info::device::max_work_group_size. Adjust the work-group size if needed.
|
||||
*/
|
||||
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, item_ct1);
|
||||
});
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
GGML_UNUSED(item_ct1);
|
||||
rope_norm<forward, false>(
|
||||
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
|
||||
pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, row_indices, set_rows_stride);
|
||||
});
|
||||
} else {
|
||||
/*
|
||||
DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
|
||||
the limit. To get the device limit, query
|
||||
info::device::max_work_group_size. Adjust the work-group size if needed.
|
||||
*/
|
||||
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, item_ct1);
|
||||
});
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
GGML_UNUSED(item_ct1);
|
||||
rope_norm<forward, true>(
|
||||
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
|
||||
pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, row_indices, set_rows_stride);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
|
||||
const int n_dims, const int nr, const int32_t * pos, const float freq_scale,
|
||||
const float freq_base, const float ext_factor, const float attn_factor,
|
||||
const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
|
||||
GGML_ASSERT(ne0 % 2 == 0);
|
||||
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
||||
const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
|
||||
const sycl::range<3> block_nums(1, num_blocks_x, nr);
|
||||
template <bool forward, typename T, typename D>
|
||||
static void
|
||||
rope_neox_sycl(const T *x, D *dst, const int ne00, const int ne01,
|
||||
const int ne02, const int s01, const int s02, const int s03,
|
||||
const int s1, const int s2, const int s3, const int n_dims,
|
||||
const int nr, const int32_t *pos, const float freq_scale,
|
||||
const float freq_base, const float ext_factor,
|
||||
const float attn_factor, const rope_corr_dims corr_dims,
|
||||
const float *freq_factors, const int64_t *row_indices,
|
||||
const int set_rows_stride, dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(ne00 % 2 == 0);
|
||||
const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
||||
const int n_blocks_x =
|
||||
(ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
|
||||
const dpct::dim3 block_nums(nr, n_blocks_x, 1);
|
||||
|
||||
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
||||
|
||||
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
||||
|
||||
if (freq_factors == nullptr) {
|
||||
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, item_ct1);
|
||||
});
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
GGML_UNUSED(item_ct1);
|
||||
rope_neox<forward, false>(
|
||||
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
|
||||
pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, row_indices, set_rows_stride);
|
||||
});
|
||||
} else {
|
||||
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, item_ct1);
|
||||
});
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
GGML_UNUSED(item_ct1);
|
||||
rope_neox<forward, true>(
|
||||
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
|
||||
pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, row_indices, set_rows_stride);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
|
||||
const size_t s2, const int n_dims, const int nr, const int32_t * pos,
|
||||
const float freq_scale, const float freq_base, const float ext_factor,
|
||||
const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
|
||||
const mrope_sections sections, const bool is_imrope, queue_ptr stream) {
|
||||
GGML_ASSERT(ne0 % 2 == 0);
|
||||
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
||||
const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
|
||||
const sycl::range<3> grid_dims(1, n_blocks_y, nr);
|
||||
const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
|
||||
template <bool forward, typename T>
|
||||
static void
|
||||
rope_multi_sycl(const T *x, T *dst, const int ne00, const int ne01,
|
||||
const int ne02, const int s01, const int s02, const int s03,
|
||||
const int s1, const int s2, const int s3, const int n_dims,
|
||||
const int nr, const int32_t *pos, const float freq_scale,
|
||||
const float freq_base, const float ext_factor,
|
||||
const float attn_factor, const rope_corr_dims corr_dims,
|
||||
const float *freq_factors, const mrope_sections sections,
|
||||
const bool is_imrope, dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(ne00 % 2 == 0);
|
||||
const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
||||
const int n_blocks_x =
|
||||
(ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
|
||||
const dpct::dim3 block_nums(nr, n_blocks_x, 1);
|
||||
|
||||
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
||||
|
||||
const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
|
||||
// Add FP16 capability check if T could be sycl::half
|
||||
if constexpr (std::is_same_v<T, sycl::half>) {
|
||||
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
||||
}
|
||||
// launch kernel
|
||||
if (freq_factors == nullptr) {
|
||||
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
|
||||
rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
|
||||
corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
|
||||
});
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
GGML_UNUSED(item_ct1);
|
||||
rope_multi<forward, false, T>(
|
||||
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
|
||||
pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, sections, is_imrope);
|
||||
});
|
||||
} else {
|
||||
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
|
||||
rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
|
||||
corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
|
||||
});
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
GGML_UNUSED(item_ct1);
|
||||
rope_multi<forward, true, T>(
|
||||
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
|
||||
pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, sections, is_imrope);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <bool forward, typename T>
|
||||
static void
|
||||
rope_vision_sycl(const T *x, T *dst, const int ne00, const int ne01,
|
||||
const int ne02, const int s01, const int s02, const int s03,
|
||||
const int s1, const int s2, const int s3, const int n_dims,
|
||||
const int nr, const int32_t *pos, const float freq_scale,
|
||||
const float freq_base, const float ext_factor,
|
||||
const float attn_factor, const rope_corr_dims corr_dims,
|
||||
const float *freq_factors, const mrope_sections sections,
|
||||
dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(ne00 % 2 == 0);
|
||||
const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
||||
const int n_blocks_x =
|
||||
(ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
|
||||
const dpct::dim3 block_nums(nr, n_blocks_x, 1);
|
||||
|
||||
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
||||
|
||||
|
||||
// rope vision
|
||||
template <typename T>
|
||||
static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
|
||||
const size_t s2, const int n_dims, const int nr, const int32_t * pos,
|
||||
const float freq_scale, const float freq_base, const float ext_factor,
|
||||
const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
|
||||
const mrope_sections sections, queue_ptr stream) {
|
||||
GGML_ASSERT(ne0 % 2 == 0);
|
||||
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
||||
const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
|
||||
const sycl::range<3> grid_dims(1, n_blocks_y, nr);
|
||||
const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
|
||||
|
||||
const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
|
||||
// Add FP16 capability check if T could be sycl::half
|
||||
if constexpr (std::is_same_v<T, sycl::half>) {
|
||||
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
||||
}
|
||||
// launch kernel
|
||||
if (freq_factors == nullptr) {
|
||||
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
|
||||
rope_vision<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
|
||||
corr_dims, theta_scale, freq_factors, sections, item_ct1);
|
||||
});
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
GGML_UNUSED(item_ct1);
|
||||
rope_vision<forward, false, T>(
|
||||
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
|
||||
pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, sections);
|
||||
});
|
||||
} else {
|
||||
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
|
||||
rope_vision<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
|
||||
corr_dims, theta_scale, freq_factors, sections, item_ct1);
|
||||
});
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
GGML_UNUSED(item_ct1);
|
||||
rope_vision<forward, true, T>(
|
||||
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
|
||||
pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, sections);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
template <bool forward>
|
||||
void ggml_sycl_op_rope_impl(ggml_backend_sycl_context &ctx, ggml_tensor *dst,
|
||||
const ggml_tensor *set_rows = nullptr) {
|
||||
const ggml_tensor *src0 = dst->src[0];
|
||||
const ggml_tensor *src1 = dst->src[1];
|
||||
const ggml_tensor *src2 = dst->src[2];
|
||||
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->src[0]->type == dst->type);
|
||||
const int64_t ne00 = dst->src[0]->ne[0]; // head dims
|
||||
const int64_t ne01 = dst->src[0]->ne[1]; // num heads
|
||||
const int64_t ne02 = dst->src[0]->ne[2]; // num heads
|
||||
const int64_t nr = ggml_nrows(dst->src[0]);
|
||||
const float *src0_d = (const float *)src0->data;
|
||||
const float *src1_d = (const float *)src1->data;
|
||||
|
||||
const size_t s01 = dst->src[0]->nb[1] / ggml_type_size(dst->src[0]->type);
|
||||
const size_t s02 = dst->src[0]->nb[2] / ggml_type_size(dst->src[0]->type);
|
||||
void *dst_d = dst->data;
|
||||
const int64_t *row_indices = nullptr;
|
||||
ggml_type dst_type = dst->type;
|
||||
int set_rows_stride = 0;
|
||||
|
||||
if (set_rows != nullptr) {
|
||||
GGML_ASSERT(forward);
|
||||
dst_d = set_rows->data;
|
||||
row_indices = (const int64_t *)set_rows->src[1]->data;
|
||||
dst_type = set_rows->type;
|
||||
set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type);
|
||||
}
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src0->type == dst->type ||
|
||||
(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16));
|
||||
|
||||
const int64_t ne00 = src0->ne[0]; // head dims
|
||||
const int64_t ne01 = src0->ne[1]; // num heads
|
||||
const int64_t ne02 = src0->ne[2]; // num heads
|
||||
const int64_t nr = ggml_nrows(src0);
|
||||
|
||||
const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
|
||||
const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
|
||||
const size_t s03 = src0->nb[3] / ggml_type_size(src0->type);
|
||||
|
||||
const size_t s1 = dst->nb[1] / ggml_type_size(dst->type);
|
||||
const size_t s2 = dst->nb[2] / ggml_type_size(dst->type);
|
||||
const size_t s3 = dst->nb[3] / ggml_type_size(dst->type);
|
||||
|
||||
const int n_dims = ((int32_t *)dst->op_params)[1];
|
||||
const int mode = ((int32_t *)dst->op_params)[2];
|
||||
const int n_ctx_orig = ((int32_t *)dst->op_params)[4];
|
||||
mrope_sections sections;
|
||||
|
||||
// RoPE alteration for extended context
|
||||
float freq_base;
|
||||
float freq_scale;
|
||||
float ext_factor;
|
||||
|
|
@ -382,13 +506,13 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
|
|||
float beta_fast;
|
||||
float beta_slow;
|
||||
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
memcpy(§ions.v, (int32_t *) dst->op_params + 11, sizeof(int)*4);
|
||||
memcpy(&freq_base, (int32_t *)dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *)dst->op_params + 6, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *)dst->op_params + 7, sizeof(float));
|
||||
memcpy(&attn_factor, (int32_t *)dst->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *)dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *)dst->op_params + 10, sizeof(float));
|
||||
memcpy(§ions.v, (int32_t *)dst->op_params + 11, sizeof(int) * 4);
|
||||
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
||||
|
|
@ -396,82 +520,122 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
|
|||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||
|
||||
if (is_mrope) {
|
||||
GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
|
||||
GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 ||
|
||||
sections.v[2] > 0);
|
||||
}
|
||||
|
||||
if (is_vision) {
|
||||
GGML_ASSERT(n_dims == ne00/2);
|
||||
GGML_ASSERT(n_dims == ne00 / 2);
|
||||
}
|
||||
|
||||
const int32_t * pos = (const int32_t *) dst->src[1]->data;
|
||||
const int32_t *pos = (const int32_t *)src1_d;
|
||||
|
||||
const float * freq_factors = nullptr;
|
||||
if (dst->src[2] != nullptr) {
|
||||
freq_factors = (const float *) dst->src[2]->data;
|
||||
const float *freq_factors = nullptr;
|
||||
if (src2 != nullptr) {
|
||||
freq_factors = (const float *)src2->data;
|
||||
}
|
||||
|
||||
rope_corr_dims corr_dims;
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
|
||||
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast,
|
||||
beta_slow, corr_dims.v);
|
||||
|
||||
// compute
|
||||
if (is_neox) {
|
||||
GGML_SYCL_DEBUG("%s: neox path\n", __func__);
|
||||
if (dst->src[0]->type == GGML_TYPE_F32) {
|
||||
rope_neox_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
|
||||
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
|
||||
} else if (dst->src[0]->type == GGML_TYPE_F16) {
|
||||
rope_neox_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
|
||||
n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
|
||||
main_stream);
|
||||
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
|
||||
rope_neox_sycl<forward, float, float>(
|
||||
(const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
|
||||
s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
|
||||
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
|
||||
set_rows_stride, stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
|
||||
rope_neox_sycl<forward, float, sycl::half>(
|
||||
(const float *)src0_d, (sycl::half *)dst_d, ne00, ne01, ne02,
|
||||
s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
|
||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
|
||||
row_indices, set_rows_stride, stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
|
||||
rope_neox_sycl<forward, sycl::half, sycl::half>(
|
||||
(const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
|
||||
ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
|
||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
|
||||
row_indices, set_rows_stride, stream);
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
GGML_ABORT("Fatal error: Tensor type unsupported!");
|
||||
}
|
||||
} else if (is_mrope && !is_vision) {
|
||||
GGML_SYCL_DEBUG("%s: mrope path\n", __func__);
|
||||
if (dst->src[0]->type == GGML_TYPE_F16) {
|
||||
rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01,
|
||||
s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
||||
freq_factors, sections, is_imrope, main_stream);
|
||||
} else if (dst->src[0]->type == GGML_TYPE_F32) {
|
||||
rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
|
||||
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
|
||||
is_imrope, main_stream);
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
rope_multi_sycl<forward>((const float *)src0_d, (float *)dst_d,
|
||||
ne00, ne01, ne02, s01, s02, s03, s1, s2,
|
||||
s3, n_dims, nr, pos, freq_scale, freq_base,
|
||||
ext_factor, attn_factor, corr_dims,
|
||||
freq_factors, sections, is_imrope, stream);
|
||||
} else if (src0->type == GGML_TYPE_F16) {
|
||||
rope_multi_sycl<forward>(
|
||||
(const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
|
||||
ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
|
||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
|
||||
sections, is_imrope, stream);
|
||||
} else {
|
||||
GGML_ABORT("Fatal error: Tensor type unsupported!");
|
||||
}
|
||||
} else if (is_vision) {
|
||||
GGML_SYCL_DEBUG("%s: vision path\n", __func__);
|
||||
if (dst->src[0]->type == GGML_TYPE_F16) {
|
||||
rope_vision_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, ne02, s01,
|
||||
s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
||||
freq_factors, sections, main_stream);
|
||||
} else if (dst->src[0]->type == GGML_TYPE_F32) {
|
||||
rope_vision_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
|
||||
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
|
||||
main_stream);
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
rope_vision_sycl<forward>(
|
||||
(const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
|
||||
s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
|
||||
ext_factor, attn_factor, corr_dims, freq_factors, sections,
|
||||
stream);
|
||||
} else if (src0->type == GGML_TYPE_F16) {
|
||||
rope_vision_sycl<forward>(
|
||||
(const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
|
||||
ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
|
||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
|
||||
sections, stream);
|
||||
} else {
|
||||
GGML_ABORT("Fatal error: Tensor type unsupported!");
|
||||
}
|
||||
} else {
|
||||
GGML_SYCL_DEBUG("%s: norm path\n", __func__);
|
||||
if (dst->src[0]->type == GGML_TYPE_F32) {
|
||||
rope_norm_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
|
||||
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
|
||||
} else if (dst->src[0]->type == GGML_TYPE_F16) {
|
||||
rope_norm_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
|
||||
n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
|
||||
main_stream);
|
||||
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
|
||||
rope_norm_sycl<forward, float, float>(
|
||||
(const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
|
||||
s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
|
||||
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
|
||||
set_rows_stride, stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
|
||||
rope_norm_sycl<forward, float, sycl::half>(
|
||||
(const float *)src0_d, (sycl::half *)dst_d, ne00, ne01, ne02,
|
||||
s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
|
||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
|
||||
row_indices, set_rows_stride, stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
|
||||
rope_norm_sycl<forward, sycl::half, sycl::half>(
|
||||
(const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
|
||||
ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
|
||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
|
||||
row_indices, set_rows_stride, stream);
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
GGML_ABORT("Fatal error: Tensor type unsupported!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
void ggml_sycl_rope(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
|
||||
ggml_sycl_op_rope(ctx, dst);
|
||||
|
||||
ggml_sycl_op_rope_impl<true>(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_rope_back(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
|
||||
ggml_sycl_op_rope_impl<false>(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_rope_fused(ggml_backend_sycl_context &ctx, ggml_tensor *rope,
|
||||
ggml_tensor *set_rows) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, rope, /*num_src=*/3);
|
||||
ggml_sycl_op_rope_impl<true>(ctx, rope, set_rows);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,6 +15,12 @@
|
|||
|
||||
#include "common.hpp"
|
||||
|
||||
#define SYCL_ROPE_BLOCK_SIZE 256
|
||||
|
||||
void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
|
||||
|
||||
void ggml_sycl_rope_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_rope_fused(ggml_backend_sycl_context & ctx, ggml_tensor * dst, ggml_tensor * set_rows);
|
||||
|
||||
#endif // GGML_SYCL_ROPE_HPP
|
||||
|
|
|
|||
|
|
@ -42,11 +42,20 @@
|
|||
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
|
||||
|
||||
// Matrix-vector multiplication parameters
|
||||
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
|
||||
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
|
||||
|
||||
// Must be multiple of 4 to work with vectorized paths, and must divide
|
||||
// mul_mat_vec wg size
|
||||
#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
|
||||
#define WEBGPU_MUL_MAT_VEC_TILE_K 256
|
||||
#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 64
|
||||
#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 256
|
||||
|
||||
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 64
|
||||
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 256
|
||||
|
||||
// Requires 32 threads per output (wg_size/outputs_per_wg == 32)
|
||||
#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8
|
||||
// Requires at least two (and multiple of 2) k-quant blocks per tile
|
||||
#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 512
|
||||
|
||||
// default size for legacy matrix multiplication
|
||||
#define WEBGPU_MUL_MAT_WG_SIZE 256
|
||||
|
|
@ -199,7 +208,8 @@ struct ggml_webgpu_binary_pipeline_key {
|
|||
bool src_overlap;
|
||||
|
||||
bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
|
||||
return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap;
|
||||
return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap &&
|
||||
src_overlap == other.src_overlap;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -749,6 +759,36 @@ class ggml_webgpu_shader_lib {
|
|||
std::vector<std::string> defines;
|
||||
std::string variant = "mul_mat_vec";
|
||||
|
||||
// src0 type (matrix row)
|
||||
switch (context.src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("SRC0_INNER_TYPE=f32");
|
||||
defines.push_back("MUL_ACC_FLOAT");
|
||||
variant += "_f32";
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back("SRC0_INNER_TYPE=f16");
|
||||
defines.push_back("MUL_ACC_FLOAT");
|
||||
variant += "_f16";
|
||||
break;
|
||||
default:
|
||||
{
|
||||
// Quantized types: use helpers but accumulate in f16
|
||||
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
|
||||
std::string src0_name = src0_traits->type_name;
|
||||
std::string type_upper = src0_name;
|
||||
variant += "_" + src0_name;
|
||||
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
||||
|
||||
defines.push_back("BYTE_HELPERS");
|
||||
defines.push_back("MUL_ACC_" + type_upper);
|
||||
|
||||
// For fast path we always dequantize from f16 inside the shader
|
||||
defines.push_back("SRC0_INNER_TYPE=f16");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// src1 type (vector)
|
||||
switch (context.src1->type) {
|
||||
case GGML_TYPE_F32:
|
||||
|
|
@ -763,39 +803,21 @@ class ggml_webgpu_shader_lib {
|
|||
GGML_ABORT("Unsupported src1 type for mul_mat_vec shader");
|
||||
}
|
||||
|
||||
// src0 type (matrix row)
|
||||
switch (context.src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("SRC0_INNER_TYPE=f32");
|
||||
defines.push_back("MUL_ACC_FLOAT");
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back("SRC0_INNER_TYPE=f16");
|
||||
defines.push_back("MUL_ACC_FLOAT");
|
||||
break;
|
||||
default:
|
||||
{
|
||||
// Quantized types: use helpers but accumulate in f16
|
||||
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
|
||||
std::string src0_name = src0_traits->type_name;
|
||||
std::string type_upper = src0_name;
|
||||
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
||||
|
||||
defines.push_back("BYTE_HELPERS");
|
||||
defines.push_back("MUL_ACC_" + type_upper);
|
||||
|
||||
// For fast path we always dequantize from f16 inside the shader
|
||||
defines.push_back("SRC0_INNER_TYPE=f16");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// VEC/SCALAR controls
|
||||
defines.push_back(key.vectorized ? "VEC" : "SCALAR");
|
||||
|
||||
uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
|
||||
uint32_t tile_k = WEBGPU_MUL_MAT_VEC_TILE_K;
|
||||
uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
|
||||
uint32_t tile_k = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K;
|
||||
uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;
|
||||
|
||||
if (key.src0_type >= GGML_TYPE_Q2_K) {
|
||||
tile_k = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K;
|
||||
outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
|
||||
} else if (key.src0_type >= GGML_TYPE_Q4_0) {
|
||||
tile_k = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K;
|
||||
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
|
||||
}
|
||||
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
||||
defines.push_back(std::string("TILE_K=") + std::to_string(tile_k));
|
||||
defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
|
||||
|
|
@ -1061,10 +1083,10 @@ class ggml_webgpu_shader_lib {
|
|||
|
||||
webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_binary_pipeline_key key = {
|
||||
.type = context.dst->type,
|
||||
.op = context.dst->op,
|
||||
.inplace = context.inplace,
|
||||
.overlap = context.overlap,
|
||||
.type = context.dst->type,
|
||||
.op = context.dst->op,
|
||||
.inplace = context.inplace,
|
||||
.overlap = context.overlap,
|
||||
.src_overlap = context.src_overlap,
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@
|
|||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-webgpu-shader-lib.hpp"
|
||||
#include "pre_wgsl.hpp"
|
||||
|
||||
#ifdef __EMSCRIPTEN__
|
||||
# include <emscripten/emscripten.h>
|
||||
|
|
@ -20,12 +19,18 @@
|
|||
#include <condition_variable>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
# include <iomanip>
|
||||
#endif
|
||||
#if defined(GGML_WEBGPU_DEBUG) || defined(GGML_WEBGPU_CPU_PROFILE) || defined(GGML_WEBGPU_GPU_PROFILE)
|
||||
# include <iostream>
|
||||
#endif
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1))
|
||||
|
|
@ -70,22 +75,21 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim
|
|||
#endif // GGML_WEBGPU_CPU_PROFILE
|
||||
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 24
|
||||
# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 32
|
||||
# define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16 // e.g. enough for two timestamps
|
||||
#endif
|
||||
|
||||
/* Constants */
|
||||
|
||||
#define WEBGPU_NUM_PARAM_BUFS 48u
|
||||
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16u
|
||||
#define WEBGPU_NUM_PARAM_BUFS 96u
|
||||
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u
|
||||
#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
|
||||
// Maximum number of in-flight submissions per-thread, to avoid exhausting the
|
||||
// parameter buffer pool
|
||||
#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE
|
||||
#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE)
|
||||
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
|
||||
#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 16
|
||||
#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
|
||||
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
|
||||
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
|
||||
|
||||
// For operations which process a row in parallel, this seems like a reasonable
|
||||
// default
|
||||
|
|
@ -118,14 +122,9 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
|
|||
wgpu::BufferUsage usage,
|
||||
const char * label);
|
||||
|
||||
struct webgpu_pool_bufs {
|
||||
wgpu::Buffer host_buf;
|
||||
wgpu::Buffer dev_buf;
|
||||
};
|
||||
|
||||
// Holds a pool of parameter buffers for WebGPU operations
|
||||
struct webgpu_buf_pool {
|
||||
std::vector<webgpu_pool_bufs> free;
|
||||
std::vector<wgpu::Buffer> free;
|
||||
|
||||
// The pool must be synchronized because
|
||||
// 1. The memset pool is shared globally by every ggml buffer,
|
||||
|
|
@ -138,7 +137,6 @@ struct webgpu_buf_pool {
|
|||
size_t cur_pool_size;
|
||||
size_t max_pool_size;
|
||||
wgpu::Device device;
|
||||
wgpu::BufferUsage host_buf_usage;
|
||||
wgpu::BufferUsage dev_buf_usage;
|
||||
size_t buf_size;
|
||||
bool should_grow;
|
||||
|
|
@ -147,53 +145,47 @@ struct webgpu_buf_pool {
|
|||
int num_bufs,
|
||||
size_t buf_size,
|
||||
wgpu::BufferUsage dev_buf_usage,
|
||||
wgpu::BufferUsage host_buf_usage,
|
||||
bool should_grow = false,
|
||||
size_t max_pool_size = WEBGPU_NUM_PARAM_BUFS * 2) {
|
||||
this->max_pool_size = max_pool_size;
|
||||
this->cur_pool_size = num_bufs;
|
||||
this->device = device;
|
||||
this->host_buf_usage = host_buf_usage;
|
||||
this->dev_buf_usage = dev_buf_usage;
|
||||
this->buf_size = buf_size;
|
||||
this->should_grow = should_grow;
|
||||
this->max_pool_size = max_pool_size;
|
||||
this->cur_pool_size = num_bufs;
|
||||
this->device = device;
|
||||
this->dev_buf_usage = dev_buf_usage;
|
||||
this->buf_size = buf_size;
|
||||
this->should_grow = should_grow;
|
||||
for (int i = 0; i < num_bufs; i++) {
|
||||
wgpu::Buffer host_buf;
|
||||
wgpu::Buffer dev_buf;
|
||||
ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
|
||||
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
|
||||
free.push_back({ host_buf, dev_buf });
|
||||
free.push_back(dev_buf);
|
||||
}
|
||||
}
|
||||
|
||||
webgpu_pool_bufs alloc_bufs() {
|
||||
wgpu::Buffer alloc_bufs() {
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
if (!free.empty()) {
|
||||
webgpu_pool_bufs bufs = free.back();
|
||||
wgpu::Buffer buf = free.back();
|
||||
free.pop_back();
|
||||
return bufs;
|
||||
return buf;
|
||||
}
|
||||
|
||||
// Try growing the pool if no free buffers
|
||||
if (free.empty() && cur_pool_size < max_pool_size && should_grow) {
|
||||
cur_pool_size++;
|
||||
wgpu::Buffer host_buf;
|
||||
wgpu::Buffer dev_buf;
|
||||
ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
|
||||
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
|
||||
|
||||
if (!(host_buf && dev_buf)) {
|
||||
if (!dev_buf) {
|
||||
GGML_ABORT("webgpu_buf_pool: failed to allocate buffers");
|
||||
}
|
||||
return webgpu_pool_bufs{ host_buf, dev_buf };
|
||||
return dev_buf;
|
||||
}
|
||||
cv.wait(lock, [this] { return !free.empty(); });
|
||||
webgpu_pool_bufs bufs = free.back();
|
||||
wgpu::Buffer buf = free.back();
|
||||
free.pop_back();
|
||||
return bufs;
|
||||
return buf;
|
||||
}
|
||||
|
||||
void free_bufs(std::vector<webgpu_pool_bufs> bufs) {
|
||||
void free_bufs(std::vector<wgpu::Buffer> bufs) {
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
free.insert(free.end(), bufs.begin(), bufs.end());
|
||||
cv.notify_all();
|
||||
|
|
@ -201,12 +193,9 @@ struct webgpu_buf_pool {
|
|||
|
||||
void cleanup() {
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
for (auto & bufs : free) {
|
||||
if (bufs.host_buf) {
|
||||
bufs.host_buf.Destroy();
|
||||
}
|
||||
if (bufs.dev_buf) {
|
||||
bufs.dev_buf.Destroy();
|
||||
for (auto & buf : free) {
|
||||
if (buf) {
|
||||
buf.Destroy();
|
||||
}
|
||||
}
|
||||
free.clear();
|
||||
|
|
@ -280,10 +269,9 @@ struct webgpu_gpu_profile_buf_pool {
|
|||
#endif
|
||||
|
||||
struct webgpu_command {
|
||||
uint32_t num_kernels;
|
||||
wgpu::CommandBuffer commands;
|
||||
std::vector<webgpu_pool_bufs> params_bufs;
|
||||
std::optional<webgpu_pool_bufs> set_rows_error_bufs;
|
||||
uint32_t num_kernels;
|
||||
wgpu::CommandBuffer commands;
|
||||
std::vector<wgpu::Buffer> params_bufs;
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
webgpu_gpu_profile_bufs timestamp_query_bufs;
|
||||
std::string pipeline_name;
|
||||
|
|
@ -358,6 +346,13 @@ struct webgpu_global_context_struct {
|
|||
|
||||
typedef std::shared_ptr<webgpu_global_context_struct> webgpu_global_context;
|
||||
|
||||
struct webgpu_submission {
|
||||
wgpu::FutureWaitInfo submit_done;
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
std::vector<wgpu::FutureWaitInfo> profile_futures;
|
||||
#endif
|
||||
};
|
||||
|
||||
// All the base objects needed to run operations on a WebGPU device
|
||||
struct webgpu_context_struct {
|
||||
// Points to global instances owned by ggml_backend_webgpu_reg_context
|
||||
|
|
@ -366,7 +361,8 @@ struct webgpu_context_struct {
|
|||
std::unique_ptr<ggml_webgpu_shader_lib> shader_lib;
|
||||
|
||||
webgpu_buf_pool param_buf_pool;
|
||||
webgpu_buf_pool set_rows_error_buf_pool;
|
||||
wgpu::Buffer set_rows_dev_error_buf;
|
||||
wgpu::Buffer set_rows_host_error_buf;
|
||||
|
||||
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
|
||||
|
||||
|
|
@ -458,67 +454,105 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
|
|||
/** End WebGPU object initializations */
|
||||
|
||||
/** WebGPU Actions */
|
||||
static void erase_completed(std::vector<wgpu::FutureWaitInfo> & futures) {
|
||||
|
||||
static bool ggml_backend_webgpu_handle_wait_status(wgpu::WaitStatus status, bool allow_timeout = false) {
|
||||
switch (status) {
|
||||
case wgpu::WaitStatus::Success:
|
||||
return true;
|
||||
case wgpu::WaitStatus::TimedOut:
|
||||
if (allow_timeout) {
|
||||
return false;
|
||||
}
|
||||
GGML_LOG_ERROR("ggml_webgpu: WaitAny timed out unexpectedly\n");
|
||||
return false;
|
||||
case wgpu::WaitStatus::Error:
|
||||
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
|
||||
return false;
|
||||
default:
|
||||
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
static void ggml_backend_webgpu_erase_completed_futures(std::vector<wgpu::FutureWaitInfo> & futures) {
|
||||
futures.erase(std::remove_if(futures.begin(), futures.end(),
|
||||
[](const wgpu::FutureWaitInfo & info) { return info.completed; }),
|
||||
futures.end());
|
||||
}
|
||||
|
||||
// Wait for the queue to finish processing all submitted work
|
||||
static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
|
||||
std::vector<wgpu::FutureWaitInfo> & futures,
|
||||
bool block = true) {
|
||||
// If we have too many in-flight submissions, wait on the oldest one first.
|
||||
static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context & ctx,
|
||||
std::vector<wgpu::FutureWaitInfo> & futures,
|
||||
bool block) {
|
||||
if (futures.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t timeout_ms = block ? UINT64_MAX : 0;
|
||||
while (futures.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {
|
||||
auto waitStatus = ctx->instance.WaitAny(1, &futures[0], UINT64_MAX);
|
||||
if (waitStatus == wgpu::WaitStatus::Error) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
|
||||
if (block) {
|
||||
while (!futures.empty()) {
|
||||
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
|
||||
if (ggml_backend_webgpu_handle_wait_status(waitStatus)) {
|
||||
ggml_backend_webgpu_erase_completed_futures(futures);
|
||||
}
|
||||
}
|
||||
if (futures[0].completed) {
|
||||
futures.erase(futures.begin());
|
||||
} else {
|
||||
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
|
||||
if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) {
|
||||
ggml_backend_webgpu_erase_completed_futures(futures);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// Wait for the queue to finish processing all submitted work
|
||||
static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
|
||||
std::vector<webgpu_submission> & subs,
|
||||
bool block = true) {
|
||||
// If we have too many in-flight submissions, wait on the oldest one first.
|
||||
if (subs.empty()) {
|
||||
return;
|
||||
}
|
||||
while (subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {
|
||||
auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, UINT64_MAX);
|
||||
if (ggml_backend_webgpu_handle_wait_status(waitStatus)) {
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true);
|
||||
#endif
|
||||
subs.erase(subs.begin());
|
||||
}
|
||||
}
|
||||
|
||||
if (futures.empty()) {
|
||||
if (subs.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (block) {
|
||||
while (!futures.empty()) {
|
||||
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
|
||||
switch (waitStatus) {
|
||||
case wgpu::WaitStatus::Success:
|
||||
// WaitAny doesn't tell us which future completed, so we must check all futures to see which finished.
|
||||
erase_completed(futures);
|
||||
break;
|
||||
case wgpu::WaitStatus::Error:
|
||||
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
|
||||
break;
|
||||
default:
|
||||
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
|
||||
break;
|
||||
for (auto & sub : subs) {
|
||||
while (!sub.submit_done.completed) {
|
||||
auto waitStatus = ctx->instance.WaitAny(1, &sub.submit_done, UINT64_MAX);
|
||||
ggml_backend_webgpu_handle_wait_status(waitStatus);
|
||||
}
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
ggml_backend_webgpu_wait_profile_futures(ctx, sub.profile_futures, true);
|
||||
#endif
|
||||
}
|
||||
subs.clear();
|
||||
} else {
|
||||
// Poll once and return
|
||||
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
|
||||
switch (waitStatus) {
|
||||
case wgpu::WaitStatus::Success:
|
||||
// WaitAny doesn't tell us which future completed, so we must check all futures to see which finished.
|
||||
erase_completed(futures);
|
||||
break;
|
||||
case wgpu::WaitStatus::TimedOut:
|
||||
break;
|
||||
case wgpu::WaitStatus::Error:
|
||||
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
|
||||
break;
|
||||
default:
|
||||
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
|
||||
break;
|
||||
// Poll each submit future once and remove completed submissions.
|
||||
for (auto sub = subs.begin(); sub != subs.end();) {
|
||||
auto waitStatus = ctx->instance.WaitAny(1, &sub->submit_done, 0);
|
||||
ggml_backend_webgpu_handle_wait_status(waitStatus, true);
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
ggml_backend_webgpu_wait_profile_futures(ctx, sub->profile_futures, false);
|
||||
if (sub->submit_done.completed && sub->profile_futures.empty()) {
|
||||
#else
|
||||
if (sub->submit_done.completed) {
|
||||
#endif
|
||||
sub = subs.erase(sub);
|
||||
} else {
|
||||
++sub;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -554,14 +588,12 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
|
|||
}
|
||||
#endif
|
||||
|
||||
static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
|
||||
webgpu_global_context ctx,
|
||||
std::vector<webgpu_command> commands,
|
||||
webgpu_buf_pool & param_buf_pool,
|
||||
webgpu_buf_pool * set_rows_error_buf_pool = nullptr) {
|
||||
static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context & ctx,
|
||||
std::vector<webgpu_command> & commands,
|
||||
webgpu_buf_pool & param_buf_pool) {
|
||||
std::vector<wgpu::CommandBuffer> command_buffers;
|
||||
std::vector<webgpu_pool_bufs> params_bufs;
|
||||
std::vector<webgpu_pool_bufs> set_rows_error_bufs;
|
||||
std::vector<wgpu::Buffer> params_bufs;
|
||||
webgpu_submission submission;
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
std::vector<std::pair<std::string, webgpu_gpu_profile_bufs>> pipeline_name_and_ts_bufs;
|
||||
#endif
|
||||
|
|
@ -569,14 +601,9 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
|
|||
for (const auto & command : commands) {
|
||||
command_buffers.push_back(command.commands);
|
||||
params_bufs.insert(params_bufs.end(), command.params_bufs.begin(), command.params_bufs.end());
|
||||
if (command.set_rows_error_bufs) {
|
||||
set_rows_error_bufs.push_back(command.set_rows_error_bufs.value());
|
||||
}
|
||||
}
|
||||
ctx->queue.Submit(command_buffers.size(), command_buffers.data());
|
||||
|
||||
std::vector<wgpu::FutureWaitInfo> futures;
|
||||
|
||||
wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
[¶m_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
|
||||
|
|
@ -586,27 +613,7 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
|
|||
// Free the staged buffers
|
||||
param_buf_pool.free_bufs(params_bufs);
|
||||
});
|
||||
futures.push_back({ p_f });
|
||||
|
||||
for (const auto & bufs : set_rows_error_bufs) {
|
||||
wgpu::Future f = bufs.host_buf.MapAsync(
|
||||
wgpu::MapMode::Read, 0, bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
|
||||
[set_rows_error_buf_pool, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
|
||||
if (status != wgpu::MapAsyncStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
|
||||
} else {
|
||||
const uint32_t * error_data = (const uint32_t *) bufs.host_buf.GetConstMappedRange();
|
||||
if (*error_data) {
|
||||
GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
|
||||
}
|
||||
// We can't unmap in here due to WebGPU reentrancy limitations.
|
||||
if (set_rows_error_buf_pool) {
|
||||
set_rows_error_buf_pool->free_bufs({ bufs });
|
||||
}
|
||||
}
|
||||
});
|
||||
futures.push_back({ f });
|
||||
}
|
||||
submission.submit_done = { p_f };
|
||||
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
for (const auto & command : commands) {
|
||||
|
|
@ -623,14 +630,14 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
|
|||
// WebGPU timestamps are in ns; convert to ms
|
||||
double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6;
|
||||
ctx->shader_gpu_time_ms[label] += elapsed_ms;
|
||||
// We can't unmap in here due to WebGPU reentrancy limitations.
|
||||
ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
|
||||
}
|
||||
// We can't unmap in here due to WebGPU reentrancy limitations.
|
||||
ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
|
||||
});
|
||||
futures.push_back({ f });
|
||||
submission.profile_futures.push_back({ f });
|
||||
}
|
||||
#endif
|
||||
return futures;
|
||||
return submission;
|
||||
}
|
||||
|
||||
static webgpu_command ggml_backend_webgpu_build_multi(
|
||||
|
|
@ -639,32 +646,21 @@ static webgpu_command ggml_backend_webgpu_build_multi(
|
|||
const std::vector<webgpu_pipeline> & pipelines,
|
||||
const std::vector<std::vector<uint32_t>> & params_list,
|
||||
const std::vector<std::vector<wgpu::BindGroupEntry>> & bind_group_entries_list,
|
||||
const std::vector<std::pair<uint32_t, uint32_t>> & workgroups_list,
|
||||
const std::optional<webgpu_pool_bufs> & set_rows_error_bufs = std::nullopt) {
|
||||
const std::vector<std::pair<uint32_t, uint32_t>> & workgroups_list) {
|
||||
GGML_ASSERT(pipelines.size() == params_list.size());
|
||||
GGML_ASSERT(pipelines.size() == bind_group_entries_list.size());
|
||||
GGML_ASSERT(pipelines.size() == workgroups_list.size());
|
||||
|
||||
std::vector<webgpu_pool_bufs> params_bufs_list;
|
||||
std::vector<wgpu::BindGroup> bind_groups;
|
||||
std::vector<wgpu::Buffer> params_bufs_list;
|
||||
std::vector<wgpu::BindGroup> bind_groups;
|
||||
|
||||
for (size_t i = 0; i < pipelines.size(); i++) {
|
||||
webgpu_pool_bufs params_bufs = param_buf_pool.alloc_bufs();
|
||||
|
||||
ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0,
|
||||
params_bufs.host_buf.GetSize());
|
||||
uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
|
||||
for (size_t j = 0; j < params_list[i].size(); j++) {
|
||||
_params[j] = params_list[i][j];
|
||||
}
|
||||
params_bufs.host_buf.Unmap();
|
||||
wgpu::Buffer params_bufs = param_buf_pool.alloc_bufs();
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = bind_group_entries_list[i];
|
||||
uint32_t params_binding_num = entries.size();
|
||||
entries.push_back({ .binding = params_binding_num,
|
||||
.buffer = params_bufs.dev_buf,
|
||||
.offset = 0,
|
||||
.size = params_bufs.dev_buf.GetSize() });
|
||||
entries.push_back(
|
||||
{ .binding = params_binding_num, .buffer = params_bufs, .offset = 0, .size = params_bufs.GetSize() });
|
||||
|
||||
wgpu::BindGroupDescriptor bind_group_desc;
|
||||
bind_group_desc.layout = pipelines[i].pipeline.GetBindGroupLayout(0);
|
||||
|
|
@ -677,15 +673,8 @@ static webgpu_command ggml_backend_webgpu_build_multi(
|
|||
}
|
||||
|
||||
wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
|
||||
for (const auto & params_bufs : params_bufs_list) {
|
||||
encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());
|
||||
}
|
||||
|
||||
// If there are SET_ROWS operations in this submission, copy their error
|
||||
// buffers to the host.
|
||||
if (set_rows_error_bufs) {
|
||||
encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0,
|
||||
set_rows_error_bufs->host_buf.GetSize());
|
||||
for (size_t i = 0; i < params_bufs_list.size(); i++) {
|
||||
ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t));
|
||||
}
|
||||
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
|
|
@ -718,7 +707,6 @@ static webgpu_command ggml_backend_webgpu_build_multi(
|
|||
webgpu_command result = {};
|
||||
result.commands = commands;
|
||||
result.params_bufs = params_bufs_list;
|
||||
result.set_rows_error_bufs = set_rows_error_bufs;
|
||||
result.num_kernels = pipelines.size();
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
result.timestamp_query_bufs = ts_bufs;
|
||||
|
|
@ -734,13 +722,13 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_global_context &
|
|||
std::vector<uint32_t> params,
|
||||
std::vector<wgpu::BindGroupEntry> bind_group_entries,
|
||||
uint32_t wg_x,
|
||||
uint32_t wg_y = 1,
|
||||
std::optional<webgpu_pool_bufs> set_rows_error_bufs = std::nullopt) {
|
||||
uint32_t wg_y = 1) {
|
||||
return ggml_backend_webgpu_build_multi(ctx, param_buf_pool,
|
||||
{
|
||||
pipeline
|
||||
},
|
||||
{ params }, { bind_group_entries }, { { wg_x, wg_y } }, set_rows_error_bufs);
|
||||
{ std::move(params) }, { std::move(bind_group_entries) },
|
||||
{ { wg_x, wg_y } });
|
||||
}
|
||||
|
||||
static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
|
||||
|
|
@ -757,8 +745,9 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
|
|||
|
||||
webgpu_command command =
|
||||
ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);
|
||||
auto futures = ggml_backend_webgpu_submit(ctx, { command }, ctx->memset_buf_pool);
|
||||
ggml_backend_webgpu_wait(ctx, futures);
|
||||
std::vector<webgpu_command> commands = { command };
|
||||
std::vector<webgpu_submission> sub = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) };
|
||||
ggml_backend_webgpu_wait(ctx, sub);
|
||||
}
|
||||
|
||||
/** End WebGPU Actions */
|
||||
|
|
@ -805,7 +794,8 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) {
|
|||
std::cout << "\nggml_webgpu: gpu breakdown:\n";
|
||||
for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) {
|
||||
double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0;
|
||||
std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
|
||||
std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << std::fixed << std::setprecision(2)
|
||||
<< pct << "%)\n";
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
@ -978,14 +968,6 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
|
|||
|
||||
auto * decisions = static_cast<ggml_webgpu_set_rows_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
std::optional<webgpu_pool_bufs> error_bufs = std::nullopt;
|
||||
if (decisions->i64_idx) {
|
||||
error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
|
||||
if (error_bufs->host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
|
||||
error_bufs->host_buf.Unmap();
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
|
||||
|
|
@ -1018,8 +1000,10 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
|
|||
};
|
||||
|
||||
if (decisions->i64_idx) {
|
||||
entries.push_back(
|
||||
{ .binding = 3, .buffer = error_bufs->dev_buf, .offset = 0, .size = error_bufs->dev_buf.GetSize() });
|
||||
entries.push_back({ .binding = 3,
|
||||
.buffer = ctx->set_rows_dev_error_buf,
|
||||
.offset = 0,
|
||||
.size = ctx->set_rows_dev_error_buf.GetSize() });
|
||||
}
|
||||
|
||||
uint32_t threads;
|
||||
|
|
@ -1029,8 +1013,7 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
|
|||
threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
|
||||
}
|
||||
uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size);
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1,
|
||||
error_bufs);
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1);
|
||||
}
|
||||
|
||||
// Workgroup size is a common constant
|
||||
|
|
@ -1108,12 +1091,26 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
|
|||
use_fast = (src0->type == GGML_TYPE_F16);
|
||||
break;
|
||||
case GGML_TYPE_F32:
|
||||
// TODO: implement better mat-mat for k-quants, mat-vec for all k-quants except q6_K
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q8_1:
|
||||
case GGML_TYPE_Q6_K:
|
||||
use_fast = true;
|
||||
break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
// we don't have fast mat-vec for these types, but we do have (semi) fast mat-mat
|
||||
use_fast = !is_vec;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
@ -1187,17 +1184,18 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
|
|||
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
|
||||
if (use_fast && is_vec) {
|
||||
auto decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
|
||||
auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
uint32_t batches = dst->ne[2] * dst->ne[3];
|
||||
uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
|
||||
uint32_t total_wg = output_groups * batches;
|
||||
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
|
||||
} else if (use_fast) {
|
||||
auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
|
||||
auto * decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
// Fast-path tiled/subgroup calculations
|
||||
uint32_t wg_m, wg_n;
|
||||
uint32_t wg_m;
|
||||
uint32_t wg_n;
|
||||
if (decisions->use_subgroup_matrix) {
|
||||
uint32_t wg_m_sg_tile =
|
||||
decisions->subgroup_m * decisions->subgroup_matrix_m * ctx->global_ctx->capabilities.sg_mat_m;
|
||||
|
|
@ -1215,7 +1213,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
|
|||
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
|
||||
|
||||
} else { // legacy
|
||||
auto decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
uint32_t wg_size = decisions->wg_size;
|
||||
uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
|
||||
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
|
||||
|
|
@ -1514,10 +1512,10 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
|
|||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst) {
|
||||
uint32_t ne = (uint32_t) ggml_nelements(dst);
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst) {
|
||||
uint32_t ne = (uint32_t) ggml_nelements(dst);
|
||||
uint32_t dim = (uint32_t) dst->op_params[0];
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
|
|
@ -1538,28 +1536,22 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
|
|||
(uint32_t) dst->ne[2],
|
||||
(uint32_t) dst->ne[3],
|
||||
dim,
|
||||
(uint32_t)src0->ne[dim]
|
||||
(uint32_t) src0->ne[dim]
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{
|
||||
.binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src0)
|
||||
},
|
||||
{
|
||||
.binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src1)
|
||||
},
|
||||
{
|
||||
.binding = 2,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst)
|
||||
}
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
|
||||
{ .binding = 2,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
||||
};
|
||||
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
|
|
@ -1569,9 +1561,9 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
|
|||
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
||||
};
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx);
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx);
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
||||
}
|
||||
|
||||
|
|
@ -1623,7 +1615,12 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
|
|||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
float freq_base;
|
||||
float freq_scale;
|
||||
float ext_factor;
|
||||
float attn_factor;
|
||||
float beta_fast;
|
||||
float beta_slow;
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||
|
|
@ -2172,19 +2169,12 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
|
|||
case GGML_OP_SOFT_MAX:
|
||||
return ggml_webgpu_soft_max(ctx, src0, src1, src2, node);
|
||||
case GGML_OP_UNARY:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_CLAMP:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_FILL:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_LOG:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_SQR:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_SQRT:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_SIN:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_COS:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_PAD:
|
||||
|
|
@ -2192,7 +2182,6 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
|
|||
case GGML_OP_ARGMAX:
|
||||
return ggml_webgpu_argmax(ctx, src0, node);
|
||||
case GGML_OP_ARGSORT:
|
||||
return ggml_webgpu_argsort(ctx, src0, node);
|
||||
case GGML_OP_TOP_K:
|
||||
// we reuse the same argsort implementation for top_k
|
||||
return ggml_webgpu_argsort(ctx, src0, node);
|
||||
|
|
@ -2214,33 +2203,51 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
|
|||
|
||||
WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);
|
||||
|
||||
std::vector<webgpu_command> commands;
|
||||
std::vector<wgpu::FutureWaitInfo> futures;
|
||||
uint32_t num_batched_kernels = 0;
|
||||
std::vector<webgpu_command> commands;
|
||||
std::vector<webgpu_submission> subs;
|
||||
uint32_t num_batched_kernels = 0;
|
||||
bool contains_set_rows = false;
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) {
|
||||
contains_set_rows = true;
|
||||
}
|
||||
if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
|
||||
commands.push_back(*cmd);
|
||||
num_batched_kernels += cmd.value().num_kernels;
|
||||
}
|
||||
|
||||
if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
|
||||
num_batched_kernels = 0;
|
||||
std::vector<wgpu::FutureWaitInfo> compute_futures = ggml_backend_webgpu_submit(
|
||||
ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
|
||||
futures.insert(futures.end(), compute_futures.begin(), compute_futures.end());
|
||||
num_batched_kernels = 0;
|
||||
subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool));
|
||||
// Process events and check for completed submissions
|
||||
ctx->global_ctx->instance.ProcessEvents();
|
||||
ggml_backend_webgpu_wait(ctx->global_ctx, futures, false);
|
||||
ggml_backend_webgpu_wait(ctx->global_ctx, subs, false);
|
||||
commands.clear();
|
||||
}
|
||||
}
|
||||
if (!commands.empty()) {
|
||||
auto new_futures =
|
||||
ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
|
||||
futures.insert(futures.end(), new_futures.begin(), new_futures.end());
|
||||
subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool));
|
||||
commands.clear();
|
||||
}
|
||||
|
||||
ggml_backend_webgpu_wait(ctx->global_ctx, futures);
|
||||
// If there are SET_ROWS operations in this graph, copy the error buffers to the host for checking.
|
||||
if (contains_set_rows) {
|
||||
wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder();
|
||||
encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0,
|
||||
ctx->set_rows_host_error_buf.GetSize());
|
||||
wgpu::CommandBuffer set_rows_commands = encoder.Finish();
|
||||
ctx->global_ctx->queue.Submit(1, &set_rows_commands);
|
||||
ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->set_rows_host_error_buf, wgpu::MapMode::Read, 0,
|
||||
ctx->set_rows_host_error_buf.GetSize());
|
||||
const uint32_t * error_data = (const uint32_t *) ctx->set_rows_host_error_buf.GetConstMappedRange();
|
||||
if (*error_data) {
|
||||
GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
|
||||
}
|
||||
ctx->set_rows_host_error_buf.Unmap();
|
||||
}
|
||||
|
||||
ggml_backend_webgpu_wait(ctx->global_ctx, subs);
|
||||
WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx);
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
|
|
@ -2859,10 +2866,12 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
|
|||
webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
|
||||
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true);
|
||||
webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
|
||||
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
|
||||
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
|
||||
ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_dev_error_buf,
|
||||
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
|
||||
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "set_rows_dev_error_buf");
|
||||
ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_host_error_buf,
|
||||
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf");
|
||||
|
||||
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
|
||||
ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ fn store_shmem(val: vec4<f16>, idx: u32) {
|
|||
shmem[idx + 2] = val.z;
|
||||
shmem[idx + 3] = val.w;
|
||||
}
|
||||
#endif
|
||||
#endif // VEC
|
||||
|
||||
#ifdef SCALAR
|
||||
#define VEC_SIZE 1
|
||||
|
|
@ -23,7 +23,7 @@ fn store_shmem(val: vec4<f16>, idx: u32) {
|
|||
fn store_shmem(val: f16, idx: u32) {
|
||||
shmem[idx] = val;
|
||||
}
|
||||
#endif
|
||||
#endif // SCALAR
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_FLOAT
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
|
|
@ -40,7 +40,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
store_shmem(SHMEM_TYPE(src0_val), elem_idx);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // INIT_SRC0_SHMEM_FLOAT
|
||||
|
||||
#ifdef INIT_SRC1_SHMEM_FLOAT
|
||||
fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
|
||||
|
|
@ -57,7 +57,7 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3
|
|||
store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // INIT_SRC1_SHMEM_FLOAT
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q4_0
|
||||
const BLOCK_SIZE = 32u;
|
||||
|
|
@ -100,4 +100,667 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // INIT_SRC0_SHMEM_Q4_0
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q4_1
|
||||
const BLOCK_SIZE = 32u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
override BLOCKS_K = TILE_K/BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const F16_PER_BLOCK = 10u; // 1 scale + 8 packed weights + 1 mean
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let global_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
let d = src0[scale_idx];
|
||||
let m = src0[scale_idx + 1u];
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = src0[scale_idx + 2u + block_offset + j];
|
||||
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
|
||||
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_lo = f16(q_byte & 0xF) * d + m;
|
||||
let q_hi = f16((q_byte >> 4) & 0xF) * d + m;
|
||||
shmem[shmem_idx + j * 2 + k] = q_lo;
|
||||
shmem[shmem_idx + j * 2 + k + 16u] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q4_1
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q5_0
|
||||
// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
|
||||
const BLOCK_SIZE = 32u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
// tile_k is defined as 32u, so blocks_k ends up being 1 always
|
||||
override BLOCKS_K = TILE_K / BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const F16_PER_BLOCK = 11u; // 1 scale + 2 qh + 8 packed weights
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let global_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
|
||||
let d = src0[scale_idx];
|
||||
let qh0 = src0[scale_idx + 1u];
|
||||
let qh1 = src0[scale_idx + 2u];
|
||||
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
|
||||
let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
|
||||
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
|
||||
let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
|
||||
let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
|
||||
let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
|
||||
let q_lo = (f16((q_byte & 0xF) | qh_lo) - 16.0) * d;
|
||||
|
||||
shmem[shmem_idx + j * 4u + k] = q_lo; // store first weight
|
||||
shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q5_0
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q5_1
|
||||
// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
|
||||
const BLOCK_SIZE = 32u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
// tile_k is defined as 32u, so blocks_k ends up being 1 always
|
||||
override BLOCKS_K = TILE_K / BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const F16_PER_BLOCK = 12u; // 1 scale + 2 qh + 8 packed weights + 1 mean
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let global_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
|
||||
let d = src0[scale_idx];
|
||||
let m = src0[scale_idx + 1u];
|
||||
let qh0 = src0[scale_idx + 2u];
|
||||
let qh1 = src0[scale_idx + 3u];
|
||||
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
|
||||
let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
|
||||
let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
|
||||
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
|
||||
let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
|
||||
let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi)) * d + m;
|
||||
let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
|
||||
let q_lo = (f16((q_byte & 0xF) | qh_lo)) * d + m;
|
||||
|
||||
shmem[shmem_idx + j * 4u + k] = q_lo; // store first weight
|
||||
shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q5_1
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q8_0
|
||||
const BLOCK_SIZE = 32u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
override BLOCKS_K = TILE_K/BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const F16_PER_BLOCK = 17u; // 1 scale + 16 in array of weights
|
||||
const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let global_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
let d = src0[scale_idx];
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
|
||||
let q_0 = src0[scale_idx + 1u + block_offset + j];
|
||||
let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
|
||||
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
|
||||
let q_val = f16(q_byte) * d;
|
||||
shmem[shmem_idx + j * 2 + k] = q_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q8_0
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q8_1
|
||||
const BLOCK_SIZE = 32u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
override BLOCKS_K = TILE_K/BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const F16_PER_BLOCK = 18u; // 1 scale + 1 mean + 8 32-bit values in array of weights
|
||||
const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread, 2 threads per block
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let global_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
let d = src0[scale_idx];
|
||||
let m = src0[scale_idx + 1u];
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
|
||||
let q_0 = src0[scale_idx + 2u + block_offset + j];
|
||||
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
|
||||
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
|
||||
let q_val = f16(q_byte) * d + m;
|
||||
shmem[shmem_idx + j * 2 + k] = q_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q8_1
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q2_K
|
||||
const BLOCK_SIZE = 256u;
|
||||
const F16_PER_BLOCK = 42u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
// Use standard thread layout instead of lane/row_group
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
|
||||
let d = src0[scale_idx + 40u];
|
||||
let dmin = src0[scale_idx + 41u];
|
||||
|
||||
// Decode the element at position k_in_block
|
||||
let block_of_32 = k_in_block / 32u;
|
||||
let pos_in_32 = k_in_block % 32u;
|
||||
|
||||
let q_b_idx = (block_of_32 / 4u) * 32u;
|
||||
let shift = (block_of_32 % 4u) * 2u;
|
||||
let k = (pos_in_32 / 16u) * 16u;
|
||||
let l = pos_in_32 % 16u;
|
||||
|
||||
let is = k_in_block / 16u;
|
||||
|
||||
let sc_0 = src0[scale_idx + 2u * (is / 4u)];
|
||||
let sc_1 = src0[scale_idx + 2u * (is / 4u) + 1u];
|
||||
let sc_packed = bitcast<u32>(vec2(sc_0, sc_1));
|
||||
let sc = get_byte(sc_packed, is % 4u);
|
||||
|
||||
let dl = d * f16(sc & 0xFu);
|
||||
let ml = dmin * f16(sc >> 4u);
|
||||
|
||||
let q_idx = q_b_idx + k + l;
|
||||
let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
|
||||
let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
let q_byte = get_byte(q_packed, q_idx % 4u);
|
||||
let qs_val = (q_byte >> shift) & 3u;
|
||||
|
||||
let q_val = f16(qs_val) * dl - ml;
|
||||
shmem[elem_idx] = q_val;
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q2_K
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q3_K
|
||||
const BLOCK_SIZE = 256u;
|
||||
const F16_PER_BLOCK = 55u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
|
||||
let d = src0[scale_idx + 54u];
|
||||
|
||||
// Load and unpack scales
|
||||
let kmask1: u32 = 0x03030303u;
|
||||
let kmask2: u32 = 0x0f0f0f0fu;
|
||||
|
||||
var scale_vals: array<u32, 4>;
|
||||
for (var i: u32 = 0u; i < 4u; i++) {
|
||||
let scale_0 = src0[scale_idx + 48u + (2u*i)];
|
||||
let scale_1 = src0[scale_idx + 48u + (2u*i) + 1u];
|
||||
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
|
||||
}
|
||||
|
||||
var tmp: u32 = scale_vals[2];
|
||||
scale_vals[2] = ((scale_vals[0] >> 4u) & kmask2) | (((tmp >> 4u) & kmask1) << 4u);
|
||||
scale_vals[3] = ((scale_vals[1] >> 4u) & kmask2) | (((tmp >> 6u) & kmask1) << 4u);
|
||||
scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4u);
|
||||
scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2u) & kmask1) << 4u);
|
||||
|
||||
// Load hmask and qs arrays
|
||||
var hmask_vals: array<u32, 8>;
|
||||
for (var i: u32 = 0u; i < 8u; i++) {
|
||||
let hmask_0 = src0[scale_idx + (2u*i)];
|
||||
let hmask_1 = src0[scale_idx + (2u*i) + 1u];
|
||||
hmask_vals[i] = bitcast<u32>(vec2(hmask_0, hmask_1));
|
||||
}
|
||||
|
||||
var qs_vals: array<u32, 16>;
|
||||
for (var i: u32 = 0u; i < 16u; i++) {
|
||||
let qs_0 = src0[scale_idx + 16u + (2u*i)];
|
||||
let qs_1 = src0[scale_idx + 16u + (2u*i) + 1u];
|
||||
qs_vals[i] = bitcast<u32>(vec2(qs_0, qs_1));
|
||||
}
|
||||
|
||||
let half = k_in_block / 128u; // 0 or 1
|
||||
let pos_in_half = k_in_block % 128u; // 0-127
|
||||
let shift_group = pos_in_half / 32u; // 0-3
|
||||
let pos_in_32 = pos_in_half % 32u; // 0-31
|
||||
let k_group = pos_in_32 / 16u; // 0 or 1
|
||||
let l = pos_in_32 % 16u; // 0-15
|
||||
|
||||
let q_b_idx = half * 32u; // 0 or 32
|
||||
let shift = shift_group * 2u; // 0, 2, 4, 6
|
||||
let k = k_group * 16u; // 0 or 16
|
||||
let is = k_in_block / 16u; // 0-15
|
||||
|
||||
// m increments every 32 elements across entire 256 element block
|
||||
let m_shift = k_in_block / 32u; // 0-7
|
||||
let m: u32 = 1u << m_shift; // 1,2,4,8,16,32,64,128
|
||||
|
||||
let sc = get_byte(scale_vals[is / 4u], is % 4u);
|
||||
let dl = d * (f16(sc) - 32.0);
|
||||
|
||||
let q_idx = q_b_idx + k + l;
|
||||
let hm_idx = k + l;
|
||||
|
||||
let q_byte = get_byte(qs_vals[q_idx / 4u], q_idx % 4u);
|
||||
let hmask_byte = get_byte(hmask_vals[hm_idx / 4u], hm_idx % 4u);
|
||||
|
||||
let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
|
||||
let qs_val = (q_byte >> shift) & 3u;
|
||||
|
||||
let q_val = (f16(qs_val) - f16(hm)) * dl;
|
||||
shmem[elem_idx] = q_val;
|
||||
}
|
||||
}
|
||||
|
||||
#endif // INIT_SRC0_SHMEM_Q3_K
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q4_K
|
||||
const BLOCK_SIZE = 256u;
|
||||
const F16_PER_BLOCK = 72u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
|
||||
let d = src0[scale_idx];
|
||||
let dmin = src0[scale_idx + 1u];
|
||||
|
||||
// Load packed scales
|
||||
var scale_vals: array<u32, 3>;
|
||||
for (var i: u32 = 0u; i < 3u; i++) {
|
||||
let scale_0 = src0[scale_idx + 2u + (2u*i)];
|
||||
let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
|
||||
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
|
||||
}
|
||||
|
||||
// Map k_in_block to loop structure:
|
||||
// Outer loop over 64-element groups (alternating q_b_idx)
|
||||
// Inner loop over 2 shifts per group
|
||||
let group_of_64 = k_in_block / 64u; // 0-3 (maps to q_b_idx)
|
||||
let pos_in_64 = k_in_block % 64u; // 0-63
|
||||
let shift_group = pos_in_64 / 32u; // 0 or 1
|
||||
let l = pos_in_64 % 32u; // 0-31
|
||||
|
||||
let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96
|
||||
let shift = shift_group * 4u; // 0 or 4
|
||||
let is = k_in_block / 32u; // 0-7
|
||||
|
||||
var sc: u32;
|
||||
var mn: u32;
|
||||
|
||||
if (is < 4u) {
|
||||
let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
|
||||
let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
|
||||
sc = sc_byte & 63u;
|
||||
mn = min_byte & 63u;
|
||||
} else {
|
||||
let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
|
||||
let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
|
||||
let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
|
||||
|
||||
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
|
||||
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
|
||||
}
|
||||
|
||||
let dl = d * f16(sc);
|
||||
let ml = dmin * f16(mn);
|
||||
|
||||
let q_idx = q_b_idx + l;
|
||||
let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
|
||||
let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
|
||||
let q_byte = get_byte(q_packed, q_idx % 4u);
|
||||
let qs_val = (q_byte >> shift) & 0xFu;
|
||||
|
||||
let q_val = f16(qs_val) * dl - ml;
|
||||
shmem[elem_idx] = q_val;
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q4_K
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q5_K
|
||||
const BLOCK_SIZE = 256u;
|
||||
const F16_PER_BLOCK = 88u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
|
||||
let d = src0[scale_idx];
|
||||
let dmin = src0[scale_idx + 1u];
|
||||
|
||||
// Load packed scales
|
||||
var scale_vals: array<u32, 3>;
|
||||
for (var i: u32 = 0u; i < 3u; i++) {
|
||||
let scale_0 = src0[scale_idx + 2u + (2u*i)];
|
||||
let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
|
||||
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
|
||||
}
|
||||
|
||||
// The original loop processes elements in groups of 64
|
||||
// Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4]
|
||||
// But u increments EVERY 32 elements (after each l loop)
|
||||
let group_of_64 = k_in_block / 64u; // 0-3
|
||||
let pos_in_64 = k_in_block % 64u; // 0-63
|
||||
let shift_group = pos_in_64 / 32u; // 0 or 1
|
||||
let l = pos_in_64 % 32u; // 0-31
|
||||
|
||||
let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96
|
||||
let shift = shift_group * 4u; // 0 or 4
|
||||
let is = k_in_block / 32u; // 0-7
|
||||
|
||||
// u increments every 32 elements (0->1, 1->2, 2->4, 3->8, 4->16, 5->32, 6->64, 7->128)
|
||||
let u_shift = k_in_block / 32u; // 0-7
|
||||
let u: u32 = 1u << u_shift;
|
||||
|
||||
var sc: u32;
|
||||
var mn: u32;
|
||||
|
||||
if (is < 4u) {
|
||||
let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
|
||||
let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
|
||||
sc = sc_byte & 63u;
|
||||
mn = min_byte & 63u;
|
||||
} else {
|
||||
let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
|
||||
let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
|
||||
let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
|
||||
|
||||
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
|
||||
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
|
||||
}
|
||||
|
||||
let dl = d * f16(sc);
|
||||
let ml = dmin * f16(mn);
|
||||
|
||||
let q_idx = q_b_idx + l;
|
||||
let q_0 = src0[scale_idx + 24u + 2u * (q_idx / 4u)];
|
||||
let q_1 = src0[scale_idx + 24u + 2u * (q_idx / 4u) + 1u];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
|
||||
let q_byte = get_byte(q_packed, q_idx % 4u);
|
||||
|
||||
let qh_0 = src0[scale_idx + 8u + 2u * (l / 4u)];
|
||||
let qh_1 = src0[scale_idx + 8u + 2u * (l / 4u) + 1u];
|
||||
let qh_packed = bitcast<u32>(vec2(qh_0, qh_1));
|
||||
|
||||
let qh_byte = get_byte(qh_packed, l % 4u);
|
||||
|
||||
let qs_val = (q_byte >> shift) & 0xFu;
|
||||
let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);
|
||||
|
||||
let q_val = (f16(qs_val) + f16(qh_val)) * dl - ml;
|
||||
shmem[elem_idx] = q_val;
|
||||
}
|
||||
}
|
||||
|
||||
#endif // INIT_SRC0_SHMEM_Q5_K
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q6_K
|
||||
const BLOCK_SIZE = 256u;
|
||||
const F16_PER_BLOCK = 105u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
|
||||
let half = k_in_block / 128u;
|
||||
let pos_in_half = k_in_block % 128u;
|
||||
let quarter = pos_in_half / 32u;
|
||||
let l = pos_in_half % 32u;
|
||||
|
||||
let ql_b_idx = half * 64u;
|
||||
let qh_b_idx = half * 32u;
|
||||
let sc_b_idx = half * 8u;
|
||||
|
||||
// Load only ql13 word needed
|
||||
let ql13_flat = ql_b_idx + l;
|
||||
let ql13_word = ql13_flat / 4u;
|
||||
let ql13 = bitcast<u32>(vec2(
|
||||
src0[scale_idx + 2u * ql13_word],
|
||||
src0[scale_idx + 2u * ql13_word + 1u]
|
||||
));
|
||||
let ql13_b = get_byte(ql13, ql13_flat % 4u);
|
||||
|
||||
// Load only ql24 word needed
|
||||
let ql24_flat = ql_b_idx + l + 32u;
|
||||
let ql24_word = ql24_flat / 4u;
|
||||
let ql24 = bitcast<u32>(vec2(
|
||||
src0[scale_idx + 2u * ql24_word],
|
||||
src0[scale_idx + 2u * ql24_word + 1u]
|
||||
));
|
||||
let ql24_b = get_byte(ql24, ql24_flat % 4u);
|
||||
|
||||
// Load only qh word needed
|
||||
let qh_flat = qh_b_idx + l;
|
||||
let qh_word = qh_flat / 4u;
|
||||
let qh = bitcast<u32>(vec2(
|
||||
src0[scale_idx + 64u + 2u * qh_word],
|
||||
src0[scale_idx + 64u + 2u * qh_word + 1u]
|
||||
));
|
||||
let qh_b = get_byte(qh, qh_flat % 4u);
|
||||
|
||||
let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
|
||||
let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0);
|
||||
let q3 = f16((ql13_b >> 4u) | (((qh_b >> 4u) & 3u) << 4u)) - f16(32.0);
|
||||
let q4 = f16((ql24_b >> 4u) | (((qh_b >> 6u) & 3u) << 4u)) - f16(32.0);
|
||||
|
||||
// Load only the scale word needed
|
||||
let is = l / 16u;
|
||||
let sc_idx = sc_b_idx + is + quarter * 2u;
|
||||
let sc_word = sc_idx / 4u;
|
||||
let sc = bitcast<u32>(vec2(
|
||||
src0[scale_idx + 96u + 2u * sc_word],
|
||||
src0[scale_idx + 96u + 2u * sc_word + 1u]
|
||||
));
|
||||
let sc_val = get_byte_i32(sc, sc_idx % 4u);
|
||||
|
||||
let d = src0[scale_idx + 104u];
|
||||
|
||||
var q_val: f16;
|
||||
if (quarter == 0u) {
|
||||
q_val = q1;
|
||||
} else if (quarter == 1u) {
|
||||
q_val = q2;
|
||||
} else if (quarter == 2u) {
|
||||
q_val = q3;
|
||||
} else {
|
||||
q_val = q4;
|
||||
}
|
||||
|
||||
shmem[elem_idx] = d * f16(sc_val) * q_val;
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q6_K
|
||||
|
|
|
|||
|
|
@ -50,6 +50,7 @@ fn get_local_m(thread_id: u32) -> u32 {
|
|||
const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
|
||||
const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
|
||||
const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
|
||||
|
||||
var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;
|
||||
|
||||
@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
enable f16;
|
||||
|
||||
#include "common_decls.tmpl"
|
||||
|
|
@ -84,6 +83,294 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
|||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_Q4_1
|
||||
|
||||
const BLOCK_SIZE = 32;
|
||||
const NQ = 16u; // number of weights per thread
|
||||
const F16_PER_BLOCK = 10u;
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
|
||||
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
var local_sum = 0.0;
|
||||
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(src0[scale_idx]);
|
||||
let m = f32(src0[scale_idx + 1u]);
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = src0[scale_idx + 2u + block_offset + j];
|
||||
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
|
||||
let q_lo = f32(q_byte & 0xF) * d + m;
|
||||
local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k];
|
||||
local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16];
|
||||
}
|
||||
}
|
||||
}
|
||||
return local_sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_Q5_0
|
||||
|
||||
const BLOCK_SIZE = 32;
|
||||
const NQ = 16u; // number of weights per thread
|
||||
const F16_PER_BLOCK = 11u;
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
|
||||
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
var local_sum = 0.0;
|
||||
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(src0[scale_idx]);
|
||||
let qh0 = src0[scale_idx + 1u];
|
||||
let qh1 = src0[scale_idx + 2u];
|
||||
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
|
||||
let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
|
||||
let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
|
||||
let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
|
||||
let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
|
||||
let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;
|
||||
|
||||
local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k];
|
||||
local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16];
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
return local_sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef MUL_ACC_Q5_1
|
||||
|
||||
const BLOCK_SIZE = 32;
|
||||
const NQ = 16u; // number of weights per thread
|
||||
const F16_PER_BLOCK = 12u;
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
|
||||
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
var local_sum = 0.0;
|
||||
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(src0[scale_idx]);
|
||||
let m = src0[scale_idx + 1u];
|
||||
let qh0 = src0[scale_idx + 2u];
|
||||
let qh1 = src0[scale_idx + 3u];
|
||||
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
|
||||
let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
|
||||
let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
|
||||
let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + f32(m);
|
||||
let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
|
||||
let q_lo = f32((q_byte & 0xF) | qh_lo) * d + f32(m);
|
||||
|
||||
local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k];
|
||||
local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16];
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
return local_sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef MUL_ACC_Q8_0
|
||||
|
||||
const BLOCK_SIZE = 32;
|
||||
const NQ = 16u; // number of weights per thread
|
||||
const F16_PER_BLOCK = 17u;
|
||||
const WEIGHTS_PER_F16 = 2u;
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
|
||||
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
var local_sum = 0.0;
|
||||
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(src0[scale_idx]);
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = src0[scale_idx + 1 + block_offset + j];
|
||||
let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f32(q_byte) * d;
|
||||
local_sum += q_val * shared_vector[shmem_idx + j * 2 + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
return local_sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef MUL_ACC_Q8_1
|
||||
|
||||
const BLOCK_SIZE = 32;
|
||||
const NQ = 16u; // number of weights per thread
|
||||
const F16_PER_BLOCK = 18u;
|
||||
const WEIGHTS_PER_F16 = 2u;
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
|
||||
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
var local_sum = 0.0;
|
||||
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(src0[scale_idx]);
|
||||
let m = src0[scale_idx + 1u];
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = src0[scale_idx + 2u + block_offset + j];
|
||||
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f32(q_byte) * d + f32(m);
|
||||
local_sum += q_val * shared_vector[shmem_idx + j * 2 + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
return local_sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_Q6_K
|
||||
|
||||
const BLOCK_SIZE = 256u;
|
||||
const F16_PER_BLOCK = 105u;
|
||||
|
||||
fn load_u32_at(bbase: u32, byte_offset: u32) -> u32 {
|
||||
let aligned = byte_offset & ~3u;
|
||||
let idx = bbase + aligned / 2u;
|
||||
return bitcast<u32>(vec2(src0[idx], src0[idx + 1u]));
|
||||
}
|
||||
|
||||
fn byte_of(v: u32, b: u32) -> u32 {
|
||||
return (v >> (b * 8u)) & 0xFFu;
|
||||
}
|
||||
|
||||
fn sbyte_of(v: u32, b: u32) -> i32 {
|
||||
let raw = i32((v >> (b * 8u)) & 0xFFu);
|
||||
return select(raw, raw - 256, raw >= 128);
|
||||
}
|
||||
|
||||
fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
let tid = tig / 2u;
|
||||
let ix = tig % 2u;
|
||||
let ip = tid / 8u;
|
||||
let il = tid % 8u;
|
||||
let l0 = 4u * il;
|
||||
let is = 8u * ip + l0 / 16u;
|
||||
|
||||
let y_offset = 128u * ip + l0;
|
||||
let q_offset_l = 64u * ip + l0;
|
||||
let q_offset_h = 32u * ip + l0;
|
||||
|
||||
let nb = tile_size / BLOCK_SIZE;
|
||||
let k_block_start = k_outer / BLOCK_SIZE;
|
||||
|
||||
// Aligned scale byte position (is can be odd)
|
||||
let sc_base_byte = 192u + (is & ~3u);
|
||||
let sc_byte_pos = is & 3u;
|
||||
|
||||
var local_sum = 0.0;
|
||||
|
||||
for (var i = ix; i < nb; i += 2u) {
|
||||
let bbase = (idx_base + k_block_start + i) * F16_PER_BLOCK;
|
||||
|
||||
let d_raw = load_u32_at(bbase, 208u);
|
||||
let d = f32(bitcast<vec2<f16>>(d_raw)[0]);
|
||||
|
||||
let ql1_u32 = load_u32_at(bbase, q_offset_l);
|
||||
let ql2_u32 = load_u32_at(bbase, q_offset_l + 32u);
|
||||
let qh_u32 = load_u32_at(bbase, 128u + q_offset_h);
|
||||
let sc_u32_0 = load_u32_at(bbase, sc_base_byte);
|
||||
let sc_u32_1 = load_u32_at(bbase, sc_base_byte + 4u);
|
||||
|
||||
let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
|
||||
let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);
|
||||
let sc4 = sbyte_of(sc_u32_1, sc_byte_pos);
|
||||
let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u);
|
||||
|
||||
var sums = vec4<f32>(0.0, 0.0, 0.0, 0.0);
|
||||
|
||||
for (var l = 0u; l < 4u; l++) {
|
||||
let y_base = i * BLOCK_SIZE + y_offset + l;
|
||||
let yl0 = f32(shared_vector[y_base]);
|
||||
let yl1 = f32(shared_vector[y_base + 32u]);
|
||||
let yl2 = f32(shared_vector[y_base + 64u]);
|
||||
let yl3 = f32(shared_vector[y_base + 96u]);
|
||||
|
||||
let q1b = byte_of(ql1_u32, l);
|
||||
let q2b = byte_of(ql2_u32, l);
|
||||
let qhb = byte_of(qh_u32, l);
|
||||
|
||||
let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32);
|
||||
let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32);
|
||||
let dq2 = f32(i32((q1b >> 4u) | ((qhb & 0x30u) )) - 32);
|
||||
let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32);
|
||||
|
||||
sums[0] += yl0 * dq0;
|
||||
sums[1] += yl1 * dq1;
|
||||
sums[2] += yl2 * dq2;
|
||||
sums[3] += yl3 * dq3;
|
||||
}
|
||||
|
||||
local_sum += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) +
|
||||
sums[2] * f32(sc4) + sums[3] * f32(sc6));
|
||||
}
|
||||
|
||||
return local_sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
struct MulMatParams {
|
||||
offset_src0: u32,
|
||||
offset_src1: u32,
|
||||
|
|
@ -191,4 +478,3 @@ fn main(
|
|||
dst[dst_idx / VEC_SIZE] = store_val(group_base);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -293,6 +293,10 @@ class LlamaBenchData:
|
|||
for t in self.repo.tags:
|
||||
if t.name == name:
|
||||
return t.commit.hexsha[:self.build_len]
|
||||
for remote in self.repo.remotes:
|
||||
for ref in remote.refs:
|
||||
if ref.name == name or ref.remote_head == name:
|
||||
return ref.commit.hexsha[:self.build_len]
|
||||
for c in self.repo.iter_commits("--all"):
|
||||
if c.hexsha[:self.build_len] == name[:self.build_len]:
|
||||
return c.hexsha[:self.build_len]
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import os
|
|||
import sys
|
||||
import subprocess
|
||||
|
||||
HTTPLIB_VERSION = "refs/tags/v0.35.0"
|
||||
HTTPLIB_VERSION = "refs/tags/v0.37.0"
|
||||
|
||||
vendor = {
|
||||
"https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp",
|
||||
|
|
@ -15,7 +15,7 @@ vendor = {
|
|||
|
||||
# not using latest tag to avoid this issue: https://github.com/ggml-org/llama.cpp/pull/17179#discussion_r2515877926
|
||||
# "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.24/miniaudio.h": "vendor/miniaudio/miniaudio.h",
|
||||
"https://github.com/mackron/miniaudio/raw/13d161bc8d856ad61ae46b798bbeffc0f49808e8/miniaudio.h": "vendor/miniaudio/miniaudio.h",
|
||||
"https://github.com/mackron/miniaudio/raw/9634bedb5b5a2ca38c1ee7108a9358a4e233f14d/miniaudio.h": "vendor/miniaudio/miniaudio.h",
|
||||
|
||||
f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/httplib.h": "httplib.h",
|
||||
f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/split.py": "split.py",
|
||||
|
|
|
|||
|
|
@ -1087,6 +1087,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
|
|||
LLM_TENSOR_TOKEN_EMBD,
|
||||
LLM_TENSOR_OUTPUT_NORM,
|
||||
LLM_TENSOR_OUTPUT,
|
||||
LLM_TENSOR_CLS_OUT,
|
||||
LLM_TENSOR_ATTN_NORM,
|
||||
LLM_TENSOR_ATTN_Q,
|
||||
LLM_TENSOR_ATTN_Q_NORM,
|
||||
|
|
|
|||
|
|
@ -601,7 +601,7 @@ const char * llama_grammar_parser::parse_sequence(
|
|||
throw std::runtime_error(std::string("expecting an int at ") + pos);
|
||||
}
|
||||
const char * int_end = parse_int(pos);
|
||||
uint64_t min_times = std::stoul(std::string(pos, int_end - pos));
|
||||
uint64_t min_times = std::stoull(std::string(pos, int_end - pos));
|
||||
pos = parse_space(int_end, is_nested);
|
||||
|
||||
uint64_t max_times = UINT64_MAX; // default: no max limit
|
||||
|
|
@ -614,7 +614,7 @@ const char * llama_grammar_parser::parse_sequence(
|
|||
|
||||
if (is_digit_char(*pos)) {
|
||||
const char * int_end = parse_int(pos);
|
||||
max_times = std::stoul(std::string(pos, int_end - pos));
|
||||
max_times = std::stoull(std::string(pos, int_end - pos));
|
||||
pos = parse_space(int_end, is_nested);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -250,7 +250,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
|||
|
||||
const bool last = (
|
||||
cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
|
||||
(cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token
|
||||
(cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL)) // qwen3 reranking & embedding models use last token
|
||||
);
|
||||
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
|
|
@ -2552,7 +2552,7 @@ void llm_graph_context::build_pooling(
|
|||
}
|
||||
|
||||
// softmax for qwen3 reranker
|
||||
if (arch == LLM_ARCH_QWEN3) {
|
||||
if (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL) {
|
||||
cur = ggml_soft_max(ctx0, cur);
|
||||
}
|
||||
} break;
|
||||
|
|
|
|||
|
|
@ -863,9 +863,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
|||
|
||||
quantize_state_impl qs(model, params);
|
||||
|
||||
// these need to be set to n_layer by default
|
||||
qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
|
||||
|
||||
if (params->only_copy) {
|
||||
ftype = ml.ftype;
|
||||
}
|
||||
|
|
@ -976,6 +973,22 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
|||
// compute tensor metadata once and cache it
|
||||
std::vector<tensor_metadata> metadata(tensors.size());
|
||||
|
||||
// initialize quantization state before preliminary loop (counters for use_more_bits)
|
||||
{
|
||||
for (size_t i = 0; i < tensors.size(); ++i) {
|
||||
const auto cat = tensor_get_category(tensors[i]->tensor->name);
|
||||
if (category_is_attn_v(cat)) {
|
||||
++qs.n_attention_wv;
|
||||
}
|
||||
if (cat == tensor_category::OUTPUT) {
|
||||
qs.has_tied_embeddings = false;
|
||||
}
|
||||
metadata[i].category = cat; // save and re-use the category while we're at it
|
||||
}
|
||||
// these also need to be set to n_layer by default
|
||||
qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)qs.model.hparams.n_layer;
|
||||
}
|
||||
|
||||
// flag for --dry-run
|
||||
bool will_require_imatrix = false;
|
||||
|
||||
|
|
@ -988,16 +1001,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
|||
const struct ggml_tensor * tensor = it->tensor;
|
||||
const std::string name = ggml_get_name(tensor);
|
||||
|
||||
metadata[i].category = tensor_get_category(name);
|
||||
|
||||
if (category_is_attn_v(metadata[i].category)) {
|
||||
++qs.n_attention_wv;
|
||||
}
|
||||
|
||||
if (tensor_name_match_output_weight(name.c_str())) {
|
||||
qs.has_tied_embeddings = false;
|
||||
}
|
||||
|
||||
uint16_t i_split = params->keep_split ? it->idx : 0;
|
||||
if (!ctx_outs[i_split]) {
|
||||
ctx_outs[i_split].reset(gguf_init_empty());
|
||||
|
|
|
|||
|
|
@ -168,8 +168,9 @@ ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp,
|
|||
GGML_ASSERT(n_seqs != 0);
|
||||
GGML_ASSERT(ubatch.equal_seqs());
|
||||
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
|
||||
GGML_ASSERT(d_inner % n_head == 0);
|
||||
GGML_ASSERT(d_inner % (n_group*d_state) == 0);
|
||||
GGML_ASSERT(d_inner % n_head == 0);
|
||||
GGML_ASSERT(d_inner % d_state == 0);
|
||||
GGML_ASSERT(d_inner % n_group == 0);
|
||||
|
||||
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
|
||||
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
|
||||
|
|
|
|||
|
|
@ -149,6 +149,7 @@ endif ()
|
|||
if (NOT WIN32 OR NOT BUILD_SHARED_LIBS)
|
||||
# these tests are disabled on Windows because they use internal functions not exported with LLAMA_API (when building with shared libraries)
|
||||
llama_build_and_test(test-sampling.cpp)
|
||||
llama_build_and_test(test-reasoning-budget.cpp)
|
||||
llama_build_and_test(test-grammar-parser.cpp)
|
||||
llama_build_and_test(test-grammar-integration.cpp)
|
||||
llama_build_and_test(test-llama-grammar.cpp)
|
||||
|
|
|
|||
|
|
@ -569,6 +569,55 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
|||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"array with empty items",
|
||||
R"""({
|
||||
"type": "array",
|
||||
"items": {}
|
||||
})""",
|
||||
R"""(
|
||||
array ::= "[" space ( value ("," space value)* )? "]" space
|
||||
boolean ::= ("true" | "false") space
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
item ::= object
|
||||
null ::= "null" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
|
||||
root ::= "[" space (item ("," space item)*)? "]" space
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
value ::= object | array | string | number | boolean | null
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"array with empty items and prefixItems",
|
||||
R"""({
|
||||
"type": "array",
|
||||
"items": {},
|
||||
"prefixItems": { "type": "string" }
|
||||
})""",
|
||||
R"""(
|
||||
array ::= "[" space ( value ("," space value)* )? "]" space
|
||||
boolean ::= ("true" | "false") space
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
item ::= object
|
||||
null ::= "null" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
|
||||
root ::= "[" space (item ("," space item)*)? "]" space
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
value ::= object | array | string | number | boolean | null
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"number",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,238 @@
|
|||
#include "reasoning-budget.h"
|
||||
#include "unicode.h"
|
||||
|
||||
#include "llama.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#ifdef NDEBUG
|
||||
#undef NDEBUG
|
||||
#endif
|
||||
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// Reasoning budget sampler test helper
|
||||
// These tests use nullptr vocab which safely falls back to treating all tokens as complete
|
||||
// (The UTF-8 boundary detection logic is tested separately in test_utf8_boundary_detection)
|
||||
static void test_reasoning_budget(
|
||||
const char * test_name,
|
||||
const std::vector<llama_token> & sequence,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget,
|
||||
common_reasoning_budget_state initial_state,
|
||||
size_t expected_force_start, // token index where forcing should start (SIZE_MAX = never)
|
||||
size_t expected_force_end // token index where forcing should end (after this, no more forcing)
|
||||
) {
|
||||
// Find the maximum token ID to ensure our vocab covers all tokens
|
||||
llama_token max_token = 0;
|
||||
for (auto t : sequence) max_token = std::max(max_token, t);
|
||||
for (auto t : start_tokens) max_token = std::max(max_token, t);
|
||||
for (auto t : end_tokens) max_token = std::max(max_token, t);
|
||||
for (auto t : forced_tokens) max_token = std::max(max_token, t);
|
||||
|
||||
// Create a minimal sampler with mock vocabulary
|
||||
// For this test, we use nullptr as vocab since we're testing state transitions
|
||||
// The UTF-8 boundary check will treat all tokens as complete (safe fallback)
|
||||
auto * sampler = common_reasoning_budget_init(
|
||||
nullptr, // vocab - not used for basic state machine tests
|
||||
start_tokens,
|
||||
end_tokens,
|
||||
forced_tokens,
|
||||
budget,
|
||||
initial_state
|
||||
);
|
||||
|
||||
// Create a test token data array for checking forcing behavior
|
||||
// Vocab size must be large enough to include all tokens (start, end, forced, sequence)
|
||||
std::vector<llama_token_data> cur;
|
||||
const size_t n_vocab = (size_t)max_token + 1;
|
||||
for (size_t i = 0; i < n_vocab; i++) {
|
||||
cur.emplace_back(llama_token_data{(llama_token)i, logf((float)(i+1)), 0.0f});
|
||||
}
|
||||
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
||||
|
||||
size_t actual_force_start = SIZE_MAX;
|
||||
size_t actual_force_end = SIZE_MAX;
|
||||
|
||||
// Feed the sequence and track when forcing occurs
|
||||
for (size_t i = 0; i < sequence.size(); i++) {
|
||||
llama_sampler_accept(sampler, sequence[i]);
|
||||
|
||||
// Check if we're in forcing state by applying and seeing if logits are modified
|
||||
cur_p.selected = -1;
|
||||
for (size_t j = 0; j < cur.size(); j++) {
|
||||
cur[j].logit = logf((float)(j+1)); // reset logits
|
||||
}
|
||||
|
||||
llama_sampler_apply(sampler, &cur_p);
|
||||
|
||||
// Check if forcing is active (all logits except one should be -INFINITY)
|
||||
size_t finite_count = 0;
|
||||
llama_token finite_token = -1;
|
||||
for (size_t j = 0; j < cur.size(); j++) {
|
||||
if (std::isfinite(cur[j].logit)) {
|
||||
finite_count++;
|
||||
finite_token = cur[j].id;
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stderr, " i=%zu: token=%d, finite_count=%zu, finite_token=%d\n", i, (int)sequence[i], finite_count, (int)finite_token);
|
||||
|
||||
if (finite_count == 1) {
|
||||
if (actual_force_start == SIZE_MAX) {
|
||||
actual_force_start = i;
|
||||
}
|
||||
actual_force_end = i;
|
||||
} else if (actual_force_start != SIZE_MAX && actual_force_end != SIZE_MAX) {
|
||||
// Forcing stopped
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
llama_sampler_free(sampler);
|
||||
|
||||
// Verify forcing occurred at expected positions
|
||||
if (expected_force_start == SIZE_MAX) {
|
||||
if (actual_force_start != SIZE_MAX) {
|
||||
fprintf(stderr, "Test '%s' FAILED: Expected no forcing, but forcing occurred at %zu\n", test_name, actual_force_start);
|
||||
GGML_ASSERT(false && "Expected no forcing, but forcing occurred");
|
||||
}
|
||||
} else {
|
||||
if (actual_force_start == SIZE_MAX) {
|
||||
fprintf(stderr, "Test '%s' FAILED: Expected forcing but none occurred\n", test_name);
|
||||
GGML_ASSERT(false && "Expected forcing but none occurred");
|
||||
}
|
||||
if (actual_force_start != expected_force_start) {
|
||||
fprintf(stderr, "Test '%s' FAILED: Forcing started at %zu, expected %zu\n", test_name, actual_force_start, expected_force_start);
|
||||
GGML_ASSERT(false && "Forcing started at wrong position");
|
||||
}
|
||||
}
|
||||
|
||||
if (expected_force_end != SIZE_MAX) {
|
||||
if (actual_force_end < expected_force_end) {
|
||||
fprintf(stderr, "Test '%s' FAILED: Forcing ended at %zu, expected >= %zu\n", test_name, actual_force_end, expected_force_end);
|
||||
GGML_ASSERT(false && "Forcing ended too early");
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stderr, " Test '%s' passed (force_start=%zu, force_end=%zu)\n", test_name, actual_force_start, actual_force_end);
|
||||
(void)sequence;
|
||||
}
|
||||
|
||||
// UTF-8 boundary detection unit test
|
||||
// Tests common_utf8_is_complete() from reasoning-budget.h
|
||||
static void test_utf8_boundary_detection() {
|
||||
// Complete sequences
|
||||
GGML_ASSERT(common_utf8_is_complete("hello"));
|
||||
GGML_ASSERT(common_utf8_is_complete(""));
|
||||
GGML_ASSERT(common_utf8_is_complete("\xC2\xA0")); // complete 2-byte UTF-8 (U+00A0)
|
||||
GGML_ASSERT(common_utf8_is_complete("\xE2\x80\x9C")); // complete 3-byte UTF-8 (left double quote)
|
||||
GGML_ASSERT(common_utf8_is_complete("\xF0\x9F\x98\x80")); // complete 4-byte UTF-8 (emoji)
|
||||
GGML_ASSERT(common_utf8_is_complete("abc\xC3\xA9")); // ASCII + complete 2-byte
|
||||
|
||||
// Incomplete sequences
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("\xC2", 1))); // 2-byte start, missing continuation
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("\xE2\x80", 2))); // 3-byte start + 1 cont, missing 1
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("\xE2", 1))); // 3-byte start, missing 2
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0\x9F\x98", 3))); // 4-byte start + 2 cont, missing 1
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0\x9F", 2))); // 4-byte start + 1 cont, missing 2
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0", 1))); // 4-byte start, missing 3
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("\x80", 1))); // orphan continuation byte
|
||||
|
||||
// Mixed: ASCII followed by start of multi-byte
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("hello\xC3", 6))); // ASCII + incomplete 2-byte
|
||||
GGML_ASSERT(common_utf8_is_complete(std::string("hello\xC3\xA9", 7))); // ASCII + complete 2-byte
|
||||
}
|
||||
|
||||
int main(void) {
|
||||
// Reasoning budget sampler tests
|
||||
printf("Testing reasoning budget sampler... ");
|
||||
|
||||
// Test 1: Basic budget with start/end tokens - no forcing (natural end before budget exhausted)
|
||||
{
|
||||
const std::vector<llama_token> start = {100}; // start token
|
||||
const std::vector<llama_token> end = {101}; // end token
|
||||
const std::vector<llama_token> forced = {102}; // forced token (not used in this test)
|
||||
const std::vector<llama_token> sequence = {100, 50, 51, 101, 52}; // start, two tokens, end, one more
|
||||
|
||||
test_reasoning_budget("natural end before budget exhausted", sequence, start, end, forced,
|
||||
5, // budget of 5 tokens
|
||||
REASONING_BUDGET_IDLE,
|
||||
SIZE_MAX, SIZE_MAX); // no forcing expected (natural end)
|
||||
}
|
||||
|
||||
// Test 2: Budget exhausted, forcing should occur
|
||||
// Flow: i=0 accept(100)->COUNTING, i=1 accept(50)->remaining=1, i=2 accept(51)->remaining=0->FORCING
|
||||
// Forcing is active at i=2 and i=3 (when apply() is called while in FORCING state)
|
||||
// At i=4, force_pos becomes 2 which equals forced_tokens.size(), so state becomes DONE
|
||||
{
|
||||
const std::vector<llama_token> start = {100};
|
||||
const std::vector<llama_token> end = {101};
|
||||
const std::vector<llama_token> forced = {102, 101}; // forced message + end
|
||||
const std::vector<llama_token> sequence = {100, 50, 51, 52, 53}; // start + 4 tokens (budget=2)
|
||||
|
||||
test_reasoning_budget("budget exhausted forcing", sequence, start, end, forced,
|
||||
2, // budget of 2 tokens
|
||||
REASONING_BUDGET_IDLE,
|
||||
2, // forcing starts at i=2 (after accept(51) depletes budget, apply() forces)
|
||||
3); // forcing continues through i=3 (at i=4 state becomes DONE)
|
||||
}
|
||||
|
||||
// Test 3: Activate immediately with budget=0, forcing should start right away
|
||||
// Flow: Since no start token in sequence, state stays IDLE (no start/end configured means passthrough)
|
||||
// This test needs start token to be in the sequence or use activate_immediately with start token present
|
||||
{
|
||||
const std::vector<llama_token> start = {100};
|
||||
const std::vector<llama_token> end = {101};
|
||||
const std::vector<llama_token> forced = {102, 101};
|
||||
const std::vector<llama_token> sequence = {100, 50, 51, 52}; // start token first, then 3 tokens
|
||||
|
||||
test_reasoning_budget("activate immediately budget=0", sequence, start, end, forced,
|
||||
0, // budget of 0 tokens
|
||||
REASONING_BUDGET_COUNTING, // starts counting, promoted to FORCING since budget=0
|
||||
0, // forcing starts at i=0 (after accept(100), budget=0 goes straight to FORCING)
|
||||
1); // forcing continues through i=1 (at i=2 state becomes DONE)
|
||||
}
|
||||
|
||||
// Test 4: No start/end tokens configured - passthrough (no forcing)
|
||||
{
|
||||
const std::vector<llama_token> start = {};
|
||||
const std::vector<llama_token> end = {};
|
||||
const std::vector<llama_token> forced = {102};
|
||||
const std::vector<llama_token> sequence = {50, 51, 52, 53};
|
||||
|
||||
test_reasoning_budget("no start/end configured", sequence, start, end, forced,
|
||||
2, // budget
|
||||
REASONING_BUDGET_IDLE,
|
||||
SIZE_MAX, SIZE_MAX); // no forcing (no start/end configured)
|
||||
}
|
||||
|
||||
// Test 5: Activate immediately with budget > 0, count down then force
|
||||
// Flow: i=0 accept(50)->remaining=1, i=1 accept(51)->remaining=0->FORCING
|
||||
// So forcing starts at i=1 (apply after accept sees FORCING with force_pos=0)
|
||||
{
|
||||
const std::vector<llama_token> start = {100};
|
||||
const std::vector<llama_token> end = {101};
|
||||
const std::vector<llama_token> forced = {102, 101};
|
||||
const std::vector<llama_token> sequence = {50, 51, 52, 53};
|
||||
|
||||
test_reasoning_budget("activate immediately with budget", sequence, start, end, forced,
|
||||
2, // budget of 2 tokens
|
||||
REASONING_BUDGET_COUNTING,
|
||||
1, // forcing starts at i=1 (after 2 accepts deplete budget)
|
||||
2); // forcing continues through i=2
|
||||
}
|
||||
|
||||
printf("OK (5 tests passed)\n");
|
||||
|
||||
printf("Testing UTF-8 boundary detection... ");
|
||||
test_utf8_boundary_detection();
|
||||
printf("OK\n");
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -57,6 +57,8 @@ struct cli_context {
|
|||
std::vector<raw_buffer> input_files;
|
||||
task_params defaults;
|
||||
bool verbose_prompt;
|
||||
int reasoning_budget = -1;
|
||||
std::string reasoning_budget_message;
|
||||
|
||||
// thread for showing "loading" animation
|
||||
std::atomic<bool> loading_show;
|
||||
|
|
@ -73,6 +75,8 @@ struct cli_context {
|
|||
// defaults.return_progress = true; // TODO: show progress
|
||||
|
||||
verbose_prompt = params.verbose_prompt;
|
||||
reasoning_budget = params.reasoning_budget;
|
||||
reasoning_budget_message = params.reasoning_budget_message;
|
||||
}
|
||||
|
||||
std::string generate_completion(result_timings & out_timings) {
|
||||
|
|
@ -95,6 +99,24 @@ struct cli_context {
|
|||
task.params.chat_parser_params.parser.load(chat_params.parser);
|
||||
}
|
||||
|
||||
// reasoning budget sampler
|
||||
if (reasoning_budget >= 0 && !chat_params.thinking_end_tag.empty()) {
|
||||
const llama_vocab * vocab = llama_model_get_vocab(
|
||||
llama_get_model(ctx_server.get_llama_context()));
|
||||
|
||||
task.params.sampling.reasoning_budget_tokens = reasoning_budget;
|
||||
task.params.sampling.reasoning_budget_activate_immediately = chat_params.thinking_forced_open;
|
||||
|
||||
if (!chat_params.thinking_start_tag.empty()) {
|
||||
task.params.sampling.reasoning_budget_start =
|
||||
common_tokenize(vocab, chat_params.thinking_start_tag, false, true);
|
||||
}
|
||||
task.params.sampling.reasoning_budget_end =
|
||||
common_tokenize(vocab, chat_params.thinking_end_tag, false, true);
|
||||
task.params.sampling.reasoning_budget_forced =
|
||||
common_tokenize(vocab, reasoning_budget_message + chat_params.thinking_end_tag, false, true);
|
||||
}
|
||||
|
||||
rd.post_task({std::move(task)});
|
||||
}
|
||||
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -1101,6 +1101,22 @@ json oaicompat_chat_params_parse(
|
|||
llama_params["chat_parser"] = chat_params.parser;
|
||||
}
|
||||
|
||||
// Reasoning budget: pass parameters through to sampling layer
|
||||
{
|
||||
int reasoning_budget = opt.reasoning_budget;
|
||||
if (reasoning_budget == -1 && body.contains("thinking_budget_tokens")) {
|
||||
reasoning_budget = json_value(body, "thinking_budget_tokens", -1);
|
||||
}
|
||||
|
||||
if (reasoning_budget >= 0 && !chat_params.thinking_end_tag.empty()) {
|
||||
llama_params["reasoning_budget_tokens"] = reasoning_budget;
|
||||
llama_params["reasoning_budget_start_tag"] = chat_params.thinking_start_tag;
|
||||
llama_params["reasoning_budget_end_tag"] = chat_params.thinking_end_tag;
|
||||
llama_params["reasoning_budget_message"] = opt.reasoning_budget_message;
|
||||
llama_params["reasoning_budget_activate_immediately"] = chat_params.thinking_forced_open;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle "logprobs" field
|
||||
// TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
|
||||
if (json_value(body, "logprobs", false)) {
|
||||
|
|
|
|||
|
|
@ -287,6 +287,8 @@ struct server_chat_params {
|
|||
bool allow_image;
|
||||
bool allow_audio;
|
||||
bool enable_thinking = true;
|
||||
int reasoning_budget = -1;
|
||||
std::string reasoning_budget_message;
|
||||
std::string media_path;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -893,9 +893,10 @@ private:
|
|||
}
|
||||
|
||||
// thinking is enabled if:
|
||||
// 1. It's not explicitly disabled (reasoning_budget == 0)
|
||||
// 1. It's not explicitly disabled via --reasoning off
|
||||
// 2. The chat template supports it
|
||||
const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get());
|
||||
const bool template_supports_thinking = params_base.use_jinja && common_chat_templates_support_enable_thinking(chat_templates.get());
|
||||
const bool enable_thinking = params_base.enable_reasoning != 0 && template_supports_thinking;
|
||||
SRV_INF("%s: chat template, thinking = %d\n", __func__, enable_thinking);
|
||||
|
||||
chat_params = {
|
||||
|
|
@ -907,6 +908,8 @@ private:
|
|||
/* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
|
||||
/* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
|
||||
/* enable_thinking */ enable_thinking,
|
||||
/* reasoning_budget */ params_base.reasoning_budget,
|
||||
/* reasoning_budget_msg */ params_base.reasoning_budget_message,
|
||||
/* media_path */ params_base.media_path,
|
||||
};
|
||||
}
|
||||
|
|
@ -2530,9 +2533,24 @@ private:
|
|||
slot.n_prompt_tokens_processed++;
|
||||
|
||||
// process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
|
||||
const int n_last = std::min(n_batch, 512);
|
||||
if (do_checkpoint && slot.task->n_tokens() == slot.prompt.n_tokens() + n_last) {
|
||||
break;
|
||||
// create checkpoints that many tokens before the end of the prompt:
|
||||
// - 4 + n_ubatch
|
||||
// - 4
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/20288
|
||||
{
|
||||
static const int checkpoint_offsets[] = {4 + n_ubatch, 4};
|
||||
|
||||
bool should_break = false;
|
||||
for (int offset : checkpoint_offsets) {
|
||||
const int n_last = std::min(n_batch, offset);
|
||||
if (do_checkpoint && slot.task->n_tokens() == slot.prompt.n_tokens() + n_last) {
|
||||
should_break = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (should_break) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2554,18 +2572,27 @@ private:
|
|||
slot.init_sampler();
|
||||
SLT_INF(slot, "prompt processing done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
|
||||
} else {
|
||||
// only do non-end checkpoints if the "checkpoint every n tokens" option is set
|
||||
do_checkpoint = do_checkpoint && params_base.checkpoint_every_nt > 0;
|
||||
if (do_checkpoint) {
|
||||
llama_pos last_checkpoint = 0;
|
||||
if (!slot.prompt.checkpoints.empty()) {
|
||||
last_checkpoint = slot.prompt.checkpoints.back().n_tokens;
|
||||
}
|
||||
do_checkpoint = do_checkpoint && slot.prompt.n_tokens() - batch.n_tokens - last_checkpoint >= params_base.checkpoint_every_nt;
|
||||
if (slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch) {
|
||||
// near the end of the prompt
|
||||
do_checkpoint = do_checkpoint && true;
|
||||
} else {
|
||||
// only do non-end checkpoints if the "checkpoint every n tokens" option is set
|
||||
do_checkpoint = do_checkpoint && params_base.checkpoint_every_nt > 0;
|
||||
|
||||
if (do_checkpoint) {
|
||||
SLT_INF(slot, "%d tokens since last checkpoint at %d, creating new checkpoint during processing at position %d\n", params_base.checkpoint_every_nt, last_checkpoint, slot.prompt.n_tokens());
|
||||
llama_pos last_checkpoint = 0;
|
||||
if (!slot.prompt.checkpoints.empty()) {
|
||||
last_checkpoint = slot.prompt.checkpoints.back().n_tokens;
|
||||
}
|
||||
|
||||
do_checkpoint = do_checkpoint && slot.prompt.n_tokens() - batch.n_tokens - last_checkpoint >= params_base.checkpoint_every_nt;
|
||||
|
||||
if (do_checkpoint) {
|
||||
SLT_INF(slot, "%d tokens since last checkpoint at %d, creating new checkpoint during processing at position %d\n", params_base.checkpoint_every_nt, last_checkpoint, slot.prompt.n_tokens());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -462,6 +462,34 @@ task_params server_task::params_from_json_cmpl(
|
|||
}
|
||||
}
|
||||
|
||||
// Parse reasoning budget sampler parameters
|
||||
{
|
||||
const int32_t budget = json_value(data, "reasoning_budget_tokens", (int32_t) -1);
|
||||
if (budget >= 0) {
|
||||
const auto start_tag = json_value(data, "reasoning_budget_start_tag", std::string());
|
||||
const auto end_tag = json_value(data, "reasoning_budget_end_tag", std::string());
|
||||
const auto message = json_value(data, "reasoning_budget_message", std::string());
|
||||
const bool activate_imm = json_value(data, "reasoning_budget_activate_immediately", false);
|
||||
|
||||
params.sampling.reasoning_budget_tokens = budget;
|
||||
params.sampling.reasoning_budget_activate_immediately = activate_imm;
|
||||
|
||||
if (!start_tag.empty()) {
|
||||
params.sampling.reasoning_budget_start = common_tokenize(vocab, start_tag, false, true);
|
||||
}
|
||||
if (!end_tag.empty()) {
|
||||
params.sampling.reasoning_budget_end = common_tokenize(vocab, end_tag, false, true);
|
||||
params.sampling.reasoning_budget_forced = common_tokenize(vocab, message + end_tag, false, true);
|
||||
}
|
||||
|
||||
SRV_DBG("reasoning budget: tokens=%d, activate_immediately=%s, start=%zu toks, end=%zu toks, forced=%zu toks\n",
|
||||
budget, activate_imm ? "true" : "false",
|
||||
params.sampling.reasoning_budget_start.size(),
|
||||
params.sampling.reasoning_budget_end.size(),
|
||||
params.sampling.reasoning_budget_forced.size());
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
params.sampling.logit_bias.clear();
|
||||
|
||||
|
|
|
|||
|
|
@ -318,6 +318,12 @@ class AgenticStore {
|
|||
const maxTurns = agenticConfig.maxTurns;
|
||||
const maxToolPreviewLines = agenticConfig.maxToolPreviewLines;
|
||||
|
||||
// Resolve effective model for vision capability checks.
|
||||
// In ROUTER mode, options.model is always set by the caller.
|
||||
// In MODEL mode, options.model is undefined; use the single loaded model
|
||||
// which carries modalities bridged from /props.
|
||||
const effectiveModel = options.model || modelsStore.models[0]?.model || '';
|
||||
|
||||
for (let turn = 0; turn < maxTurns; turn++) {
|
||||
this.updateSession(conversationId, { currentTurn: turn + 1 });
|
||||
agenticTimings.turns = turn + 1;
|
||||
|
|
@ -571,14 +577,14 @@ class AgenticStore {
|
|||
];
|
||||
for (const attachment of attachments) {
|
||||
if (attachment.type === AttachmentType.IMAGE) {
|
||||
if (modelsStore.modelSupportsVision(options.model ?? '')) {
|
||||
if (modelsStore.modelSupportsVision(effectiveModel)) {
|
||||
contentParts.push({
|
||||
type: ContentPartType.IMAGE_URL,
|
||||
image_url: { url: (attachment as DatabaseMessageExtraImageFile).base64Url }
|
||||
});
|
||||
} else {
|
||||
console.info(
|
||||
`[AgenticStore] Skipping image attachment (model "${options.model}" does not support vision)`
|
||||
`[AgenticStore] Skipping image attachment (model "${effectiveModel}" does not support vision)`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -813,17 +813,13 @@ bool is_websocket_upgrade(const Request &req) {
|
|||
// Check Upgrade: websocket (case-insensitive)
|
||||
auto upgrade_it = req.headers.find("Upgrade");
|
||||
if (upgrade_it == req.headers.end()) { return false; }
|
||||
auto upgrade_val = upgrade_it->second;
|
||||
std::transform(upgrade_val.begin(), upgrade_val.end(), upgrade_val.begin(),
|
||||
::tolower);
|
||||
auto upgrade_val = case_ignore::to_lower(upgrade_it->second);
|
||||
if (upgrade_val != "websocket") { return false; }
|
||||
|
||||
// Check Connection header contains "Upgrade"
|
||||
auto connection_it = req.headers.find("Connection");
|
||||
if (connection_it == req.headers.end()) { return false; }
|
||||
auto connection_val = connection_it->second;
|
||||
std::transform(connection_val.begin(), connection_val.end(),
|
||||
connection_val.begin(), ::tolower);
|
||||
auto connection_val = case_ignore::to_lower(connection_it->second);
|
||||
if (connection_val.find("upgrade") == std::string::npos) { return false; }
|
||||
|
||||
// Check Sec-WebSocket-Key is a valid base64-encoded 16-byte value (24 chars)
|
||||
|
|
@ -2615,10 +2611,15 @@ bool can_compress_content_type(const std::string &content_type) {
|
|||
switch (tag) {
|
||||
case "image/svg+xml"_t:
|
||||
case "application/javascript"_t:
|
||||
case "application/x-javascript"_t:
|
||||
case "application/json"_t:
|
||||
case "application/ld+json"_t:
|
||||
case "application/xml"_t:
|
||||
case "application/protobuf"_t:
|
||||
case "application/xhtml+xml"_t: return true;
|
||||
case "application/xhtml+xml"_t:
|
||||
case "application/rss+xml"_t:
|
||||
case "application/atom+xml"_t:
|
||||
case "application/xslt+xml"_t:
|
||||
case "application/protobuf"_t: return true;
|
||||
|
||||
case "text/event-stream"_t: return false;
|
||||
|
||||
|
|
@ -3038,17 +3039,13 @@ bool read_websocket_upgrade_response(Stream &strm,
|
|||
// Verify Upgrade: websocket (case-insensitive)
|
||||
auto upgrade_it = headers.find("Upgrade");
|
||||
if (upgrade_it == headers.end()) { return false; }
|
||||
auto upgrade_val = upgrade_it->second;
|
||||
std::transform(upgrade_val.begin(), upgrade_val.end(), upgrade_val.begin(),
|
||||
::tolower);
|
||||
auto upgrade_val = case_ignore::to_lower(upgrade_it->second);
|
||||
if (upgrade_val != "websocket") { return false; }
|
||||
|
||||
// Verify Connection header contains "Upgrade" (case-insensitive)
|
||||
auto connection_it = headers.find("Connection");
|
||||
if (connection_it == headers.end()) { return false; }
|
||||
auto connection_val = connection_it->second;
|
||||
std::transform(connection_val.begin(), connection_val.end(),
|
||||
connection_val.begin(), ::tolower);
|
||||
auto connection_val = case_ignore::to_lower(connection_it->second);
|
||||
if (connection_val.find("upgrade") == std::string::npos) { return false; }
|
||||
|
||||
// Verify Sec-WebSocket-Accept header value
|
||||
|
|
@ -3934,14 +3931,10 @@ public:
|
|||
file_.content_type =
|
||||
trim_copy(header.substr(str_len(header_content_type)));
|
||||
} else {
|
||||
thread_local const std::regex re_content_disposition(
|
||||
R"~(^Content-Disposition:\s*form-data;\s*(.*)$)~",
|
||||
std::regex_constants::icase);
|
||||
|
||||
std::smatch m;
|
||||
if (std::regex_match(header, m, re_content_disposition)) {
|
||||
std::string disposition_params;
|
||||
if (parse_content_disposition(header, disposition_params)) {
|
||||
Params params;
|
||||
parse_disposition_params(m[1], params);
|
||||
parse_disposition_params(disposition_params, params);
|
||||
|
||||
auto it = params.find("name");
|
||||
if (it != params.end()) {
|
||||
|
|
@ -3956,13 +3949,14 @@ public:
|
|||
|
||||
it = params.find("filename*");
|
||||
if (it != params.end()) {
|
||||
// Only allow UTF-8 encoding...
|
||||
thread_local const std::regex re_rfc5987_encoding(
|
||||
R"~(^UTF-8''(.+?)$)~", std::regex_constants::icase);
|
||||
|
||||
std::smatch m2;
|
||||
if (std::regex_match(it->second, m2, re_rfc5987_encoding)) {
|
||||
file_.filename = decode_path_component(m2[1]); // override...
|
||||
// RFC 5987: only UTF-8 encoding is allowed
|
||||
const auto &val = it->second;
|
||||
constexpr const char utf8_prefix[] = "UTF-8''";
|
||||
constexpr size_t prefix_len = str_len(utf8_prefix);
|
||||
if (val.size() > prefix_len &&
|
||||
start_with_case_ignore(val, utf8_prefix)) {
|
||||
file_.filename = decode_path_component(
|
||||
val.substr(prefix_len)); // override...
|
||||
} else {
|
||||
is_valid_ = false;
|
||||
return false;
|
||||
|
|
@ -4030,17 +4024,48 @@ private:
|
|||
file_.headers.clear();
|
||||
}
|
||||
|
||||
bool start_with_case_ignore(const std::string &a, const char *b) const {
|
||||
bool start_with_case_ignore(const std::string &a, const char *b,
|
||||
size_t offset = 0) const {
|
||||
const auto b_len = strlen(b);
|
||||
if (a.size() < b_len) { return false; }
|
||||
if (a.size() < offset + b_len) { return false; }
|
||||
for (size_t i = 0; i < b_len; i++) {
|
||||
if (case_ignore::to_lower(a[i]) != case_ignore::to_lower(b[i])) {
|
||||
if (case_ignore::to_lower(a[offset + i]) != case_ignore::to_lower(b[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Parses "Content-Disposition: form-data; <params>" without std::regex.
|
||||
// Returns true if header matches, with the params portion in `params_out`.
|
||||
bool parse_content_disposition(const std::string &header,
|
||||
std::string ¶ms_out) const {
|
||||
constexpr const char prefix[] = "Content-Disposition:";
|
||||
constexpr size_t prefix_len = str_len(prefix);
|
||||
|
||||
if (!start_with_case_ignore(header, prefix)) { return false; }
|
||||
|
||||
// Skip whitespace after "Content-Disposition:"
|
||||
auto pos = prefix_len;
|
||||
while (pos < header.size() && (header[pos] == ' ' || header[pos] == '\t')) {
|
||||
pos++;
|
||||
}
|
||||
|
||||
// Match "form-data;" (case-insensitive)
|
||||
constexpr const char form_data[] = "form-data;";
|
||||
constexpr size_t form_data_len = str_len(form_data);
|
||||
if (!start_with_case_ignore(header, form_data, pos)) { return false; }
|
||||
pos += form_data_len;
|
||||
|
||||
// Skip whitespace after "form-data;"
|
||||
while (pos < header.size() && (header[pos] == ' ' || header[pos] == '\t')) {
|
||||
pos++;
|
||||
}
|
||||
|
||||
params_out = header.substr(pos);
|
||||
return true;
|
||||
}
|
||||
|
||||
const std::string dash_ = "--";
|
||||
const std::string crlf_ = "\r\n";
|
||||
std::string boundary_;
|
||||
|
|
@ -4992,9 +5017,10 @@ bool match_hostname(const std::string &pattern,
|
|||
// Verify certificate using Windows CertGetCertificateChain API.
|
||||
// This provides real-time certificate validation with Windows Update
|
||||
// integration, independent of the TLS backend (OpenSSL or MbedTLS).
|
||||
bool verify_cert_with_windows_schannel(
|
||||
const std::vector<unsigned char> &der_cert, const std::string &hostname,
|
||||
bool verify_hostname, unsigned long &out_error) {
|
||||
bool
|
||||
verify_cert_with_windows_schannel(const std::vector<unsigned char> &der_cert,
|
||||
const std::string &hostname,
|
||||
bool verify_hostname, uint64_t &out_error) {
|
||||
if (der_cert.empty()) { return false; }
|
||||
|
||||
out_error = 0;
|
||||
|
|
@ -7987,7 +8013,7 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
|
|||
#else
|
||||
try {
|
||||
routed = routing(req, res, strm);
|
||||
} catch (std::exception &e) {
|
||||
} catch (std::exception &) {
|
||||
if (exception_handler_) {
|
||||
auto ep = std::current_exception();
|
||||
exception_handler_(req, res, ep);
|
||||
|
|
@ -11811,7 +11837,7 @@ bool SSLClient::initialize_ssl(Socket &socket, Error &error) {
|
|||
server_certificate_verification_) {
|
||||
verify_result_ = tls::get_verify_result(session);
|
||||
if (verify_result_ != 0) {
|
||||
last_backend_error_ = static_cast<unsigned long>(verify_result_);
|
||||
last_backend_error_ = static_cast<uint64_t>(verify_result_);
|
||||
error = Error::SSLServerVerification;
|
||||
output_error_log(error, nullptr);
|
||||
return false;
|
||||
|
|
@ -11850,7 +11876,7 @@ bool SSLClient::initialize_ssl(Socket &socket, Error &error) {
|
|||
ca_cert_dir_path_.empty() && ca_cert_pem_.empty()) {
|
||||
std::vector<unsigned char> der;
|
||||
if (get_cert_der(server_cert, der)) {
|
||||
unsigned long wincrypt_error = 0;
|
||||
uint64_t wincrypt_error = 0;
|
||||
if (!detail::verify_cert_with_windows_schannel(
|
||||
der, host_, server_hostname_verification_, wincrypt_error)) {
|
||||
last_backend_error_ = wincrypt_error;
|
||||
|
|
@ -11974,16 +12000,26 @@ bool is_ipv4_address(const std::string &str) {
|
|||
|
||||
// Parse IPv4 address string to bytes
|
||||
bool parse_ipv4(const std::string &str, unsigned char *out) {
|
||||
int parts[4];
|
||||
if (sscanf(str.c_str(), "%d.%d.%d.%d", &parts[0], &parts[1], &parts[2],
|
||||
&parts[3]) != 4) {
|
||||
return false;
|
||||
}
|
||||
const char *p = str.c_str();
|
||||
for (int i = 0; i < 4; i++) {
|
||||
if (parts[i] < 0 || parts[i] > 255) return false;
|
||||
out[i] = static_cast<unsigned char>(parts[i]);
|
||||
if (i > 0) {
|
||||
if (*p != '.') { return false; }
|
||||
p++;
|
||||
}
|
||||
int val = 0;
|
||||
int digits = 0;
|
||||
while (*p >= '0' && *p <= '9') {
|
||||
val = val * 10 + (*p - '0');
|
||||
if (val > 255) { return false; }
|
||||
p++;
|
||||
digits++;
|
||||
}
|
||||
if (digits == 0) { return false; }
|
||||
// Reject leading zeros (e.g., "01.002.03.04") to prevent ambiguity
|
||||
if (digits > 1 && *(p - digits) == '0') { return false; }
|
||||
out[i] = static_cast<unsigned char>(val);
|
||||
}
|
||||
return true;
|
||||
return *p == '\0';
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
|
|
@ -13285,11 +13321,11 @@ void update_server_certs_from_x509(ctx_t ctx, X509 *cert, EVP_PKEY *key,
|
|||
|
||||
ctx_t create_client_context_from_x509(X509 *cert, EVP_PKEY *key,
|
||||
const char *password,
|
||||
unsigned long &out_error) {
|
||||
uint64_t &out_error) {
|
||||
out_error = 0;
|
||||
auto ctx = create_client_context();
|
||||
if (!ctx) {
|
||||
out_error = static_cast<unsigned long>(get_error());
|
||||
out_error = get_error();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
@ -13303,7 +13339,7 @@ ctx_t create_client_context_from_x509(X509 *cert, EVP_PKEY *key,
|
|||
}
|
||||
if (!set_client_cert_pem(ctx, cert_pem.c_str(), key_pem.c_str(),
|
||||
password)) {
|
||||
out_error = static_cast<unsigned long>(get_error());
|
||||
out_error = get_error();
|
||||
free_context(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@
|
|||
#ifndef CPPHTTPLIB_HTTPLIB_H
|
||||
#define CPPHTTPLIB_HTTPLIB_H
|
||||
|
||||
#define CPPHTTPLIB_VERSION "0.35.0"
|
||||
#define CPPHTTPLIB_VERSION_NUM "0x002300"
|
||||
#define CPPHTTPLIB_VERSION "0.37.0"
|
||||
#define CPPHTTPLIB_VERSION_NUM "0x002500"
|
||||
|
||||
/*
|
||||
* Platform compatibility check
|
||||
|
|
@ -575,6 +575,14 @@ inline unsigned char to_lower(int c) {
|
|||
return table[(unsigned char)(char)c];
|
||||
}
|
||||
|
||||
inline std::string to_lower(const std::string &s) {
|
||||
std::string result = s;
|
||||
std::transform(
|
||||
result.begin(), result.end(), result.begin(),
|
||||
[](unsigned char c) { return static_cast<char>(to_lower(c)); });
|
||||
return result;
|
||||
}
|
||||
|
||||
inline bool equal(const std::string &a, const std::string &b) {
|
||||
return a.size() == b.size() &&
|
||||
std::equal(a.begin(), a.end(), b.begin(), [](char ca, char cb) {
|
||||
|
|
@ -1859,23 +1867,23 @@ public:
|
|||
: res_(std::move(res)), err_(err),
|
||||
request_headers_(std::move(request_headers)), ssl_error_(ssl_error) {}
|
||||
Result(std::unique_ptr<Response> &&res, Error err, Headers &&request_headers,
|
||||
int ssl_error, unsigned long ssl_backend_error)
|
||||
int ssl_error, uint64_t ssl_backend_error)
|
||||
: res_(std::move(res)), err_(err),
|
||||
request_headers_(std::move(request_headers)), ssl_error_(ssl_error),
|
||||
ssl_backend_error_(ssl_backend_error) {}
|
||||
|
||||
int ssl_error() const { return ssl_error_; }
|
||||
unsigned long ssl_backend_error() const { return ssl_backend_error_; }
|
||||
uint64_t ssl_backend_error() const { return ssl_backend_error_; }
|
||||
|
||||
private:
|
||||
int ssl_error_ = 0;
|
||||
unsigned long ssl_backend_error_ = 0;
|
||||
uint64_t ssl_backend_error_ = 0;
|
||||
#endif
|
||||
|
||||
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
||||
public:
|
||||
[[deprecated("Use ssl_backend_error() instead")]]
|
||||
unsigned long ssl_openssl_error() const {
|
||||
uint64_t ssl_openssl_error() const {
|
||||
return ssl_backend_error_;
|
||||
}
|
||||
#endif
|
||||
|
|
@ -2345,7 +2353,7 @@ protected:
|
|||
bool server_hostname_verification_ = true;
|
||||
std::string ca_cert_pem_; // Store CA cert PEM for redirect transfer
|
||||
int last_ssl_error_ = 0;
|
||||
unsigned long last_backend_error_ = 0;
|
||||
uint64_t last_backend_error_ = 0;
|
||||
#endif
|
||||
|
||||
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
/*
|
||||
Audio playback and capture library. Choice of public domain or MIT-0. See license statements at the end of this file.
|
||||
miniaudio - v0.11.24 - 2026-01-17
|
||||
miniaudio - v0.11.25 - 2026-03-04
|
||||
|
||||
David Reid - mackron@gmail.com
|
||||
|
||||
|
|
@ -3747,7 +3747,7 @@ extern "C" {
|
|||
|
||||
#define MA_VERSION_MAJOR 0
|
||||
#define MA_VERSION_MINOR 11
|
||||
#define MA_VERSION_REVISION 24
|
||||
#define MA_VERSION_REVISION 25
|
||||
#define MA_VERSION_STRING MA_XSTRINGIFY(MA_VERSION_MAJOR) "." MA_XSTRINGIFY(MA_VERSION_MINOR) "." MA_XSTRINGIFY(MA_VERSION_REVISION)
|
||||
|
||||
#if defined(_MSC_VER) && !defined(__clang__)
|
||||
|
|
@ -19358,7 +19358,7 @@ MA_API ma_handle ma_dlopen(ma_log* pLog, const char* filename)
|
|||
#else
|
||||
/* *sigh* It appears there is no ANSI version of LoadPackagedLibrary()... */
|
||||
WCHAR filenameW[4096];
|
||||
if (MultiByteToWideChar(CP_UTF8, 0, filename, -1, filenameW, sizeof(filenameW)) == 0) {
|
||||
if (MultiByteToWideChar(CP_UTF8, 0, filename, -1, filenameW, ma_countof(filenameW)) == 0) {
|
||||
handle = NULL;
|
||||
} else {
|
||||
handle = (ma_handle)LoadPackagedLibrary(filenameW, 0);
|
||||
|
|
@ -41495,18 +41495,37 @@ Web Audio Backend
|
|||
#ifdef MA_HAS_WEBAUDIO
|
||||
#include <emscripten/emscripten.h>
|
||||
|
||||
#if (__EMSCRIPTEN_major__ > 3) || (__EMSCRIPTEN_major__ == 3 && (__EMSCRIPTEN_minor__ > 1 || (__EMSCRIPTEN_minor__ == 1 && __EMSCRIPTEN_tiny__ >= 32)))
|
||||
#ifndef MA_EMSCRIPTEN_MAJOR
|
||||
#if defined(__EMSCRIPTEN_MAJOR__)
|
||||
#define MA_EMSCRIPTEN_MAJOR __EMSCRIPTEN_MAJOR__
|
||||
#else
|
||||
#define MA_EMSCRIPTEN_MAJOR __EMSCRIPTEN_major__
|
||||
#endif
|
||||
#endif
|
||||
#ifndef MA_EMSCRIPTEN_MINOR
|
||||
#if defined(__EMSCRIPTEN_MINOR__)
|
||||
#define MA_EMSCRIPTEN_MINOR __EMSCRIPTEN_MINOR__
|
||||
#else
|
||||
#define MA_EMSCRIPTEN_MINOR __EMSCRIPTEN_minor__
|
||||
#endif
|
||||
#endif
|
||||
#ifndef MA_EMSCRIPTEN_TINY
|
||||
#if defined(__EMSCRIPTEN_TINY__)
|
||||
#define MA_EMSCRIPTEN_TINY __EMSCRIPTEN_TINY__
|
||||
#else
|
||||
#define MA_EMSCRIPTEN_TINY __EMSCRIPTEN_tiny__
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if (MA_EMSCRIPTEN_MAJOR > 3) || (MA_EMSCRIPTEN_MAJOR == 3 && (MA_EMSCRIPTEN_MINOR > 1 || (MA_EMSCRIPTEN_MINOR == 1 && MA_EMSCRIPTEN_TINY >= 32)))
|
||||
#include <emscripten/webaudio.h>
|
||||
#define MA_SUPPORT_AUDIO_WORKLETS
|
||||
|
||||
#if (__EMSCRIPTEN_major__ > 3) || (__EMSCRIPTEN_major__ == 3 && (__EMSCRIPTEN_minor__ > 1 || (__EMSCRIPTEN_minor__ == 1 && __EMSCRIPTEN_tiny__ >= 70)))
|
||||
#if (MA_EMSCRIPTEN_MAJOR > 3) || (MA_EMSCRIPTEN_MAJOR == 3 && (MA_EMSCRIPTEN_MINOR > 1 || (MA_EMSCRIPTEN_MINOR == 1 && MA_EMSCRIPTEN_TINY >= 70)))
|
||||
#define MA_SUPPORT_AUDIO_WORKLETS_VARIABLE_BUFFER_SIZE
|
||||
#endif
|
||||
#endif
|
||||
|
||||
/*
|
||||
TODO: Version 0.12: Swap this logic around so that AudioWorklets are used by default. Add MA_NO_AUDIO_WORKLETS.
|
||||
*/
|
||||
#if defined(MA_ENABLE_AUDIO_WORKLETS) && defined(MA_SUPPORT_AUDIO_WORKLETS)
|
||||
#define MA_USE_AUDIO_WORKLETS
|
||||
#endif
|
||||
|
|
@ -59243,6 +59262,10 @@ static ma_result ma_data_source_read_pcm_frames_within_range(ma_data_source* pDa
|
|||
ma_uint64 framesRead = 0;
|
||||
ma_bool32 loop = ma_data_source_is_looping(pDataSource);
|
||||
|
||||
if (pFramesRead != NULL) {
|
||||
*pFramesRead = 0;
|
||||
}
|
||||
|
||||
if (pDataSourceBase == NULL) {
|
||||
return MA_AT_END;
|
||||
}
|
||||
|
|
@ -61921,7 +61944,7 @@ extern "C" {
|
|||
#define MA_DR_WAV_XSTRINGIFY(x) MA_DR_WAV_STRINGIFY(x)
|
||||
#define MA_DR_WAV_VERSION_MAJOR 0
|
||||
#define MA_DR_WAV_VERSION_MINOR 14
|
||||
#define MA_DR_WAV_VERSION_REVISION 4
|
||||
#define MA_DR_WAV_VERSION_REVISION 5
|
||||
#define MA_DR_WAV_VERSION_STRING MA_DR_WAV_XSTRINGIFY(MA_DR_WAV_VERSION_MAJOR) "." MA_DR_WAV_XSTRINGIFY(MA_DR_WAV_VERSION_MINOR) "." MA_DR_WAV_XSTRINGIFY(MA_DR_WAV_VERSION_REVISION)
|
||||
#include <stddef.h>
|
||||
#define MA_DR_WAVE_FORMAT_PCM 0x1
|
||||
|
|
@ -80503,6 +80526,13 @@ MA_PRIVATE ma_uint64 ma_dr_wav__read_smpl_to_metadata_obj(ma_dr_wav__metadata_pa
|
|||
MA_DR_WAV_ASSERT(pChunkHeader != NULL);
|
||||
if (pMetadata != NULL && bytesJustRead == sizeof(smplHeaderData)) {
|
||||
ma_uint32 iSampleLoop;
|
||||
ma_uint32 loopCount;
|
||||
ma_uint32 calculatedLoopCount;
|
||||
loopCount = ma_dr_wav_bytes_to_u32(smplHeaderData + 28);
|
||||
calculatedLoopCount = (pChunkHeader->sizeInBytes - MA_DR_WAV_SMPL_BYTES) / MA_DR_WAV_SMPL_LOOP_BYTES;
|
||||
if (loopCount != calculatedLoopCount) {
|
||||
return totalBytesRead;
|
||||
}
|
||||
pMetadata->type = ma_dr_wav_metadata_type_smpl;
|
||||
pMetadata->data.smpl.manufacturerId = ma_dr_wav_bytes_to_u32(smplHeaderData + 0);
|
||||
pMetadata->data.smpl.productId = ma_dr_wav_bytes_to_u32(smplHeaderData + 4);
|
||||
|
|
@ -80513,7 +80543,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav__read_smpl_to_metadata_obj(ma_dr_wav__metadata_pa
|
|||
pMetadata->data.smpl.smpteOffset = ma_dr_wav_bytes_to_u32(smplHeaderData + 24);
|
||||
pMetadata->data.smpl.sampleLoopCount = ma_dr_wav_bytes_to_u32(smplHeaderData + 28);
|
||||
pMetadata->data.smpl.samplerSpecificDataSizeInBytes = ma_dr_wav_bytes_to_u32(smplHeaderData + 32);
|
||||
if (pMetadata->data.smpl.sampleLoopCount == (pChunkHeader->sizeInBytes - MA_DR_WAV_SMPL_BYTES) / MA_DR_WAV_SMPL_LOOP_BYTES) {
|
||||
if (pMetadata->data.smpl.sampleLoopCount == calculatedLoopCount) {
|
||||
pMetadata->data.smpl.pLoops = (ma_dr_wav_smpl_loop*)ma_dr_wav__metadata_get_memory(pParser, sizeof(ma_dr_wav_smpl_loop) * pMetadata->data.smpl.sampleLoopCount, MA_DR_WAV_METADATA_ALIGNMENT);
|
||||
for (iSampleLoop = 0; iSampleLoop < pMetadata->data.smpl.sampleLoopCount; ++iSampleLoop) {
|
||||
ma_uint8 smplLoopData[MA_DR_WAV_SMPL_LOOP_BYTES];
|
||||
|
|
@ -80534,6 +80564,8 @@ MA_PRIVATE ma_uint64 ma_dr_wav__read_smpl_to_metadata_obj(ma_dr_wav__metadata_pa
|
|||
MA_DR_WAV_ASSERT(pMetadata->data.smpl.pSamplerSpecificData != NULL);
|
||||
ma_dr_wav__metadata_parser_read(pParser, pMetadata->data.smpl.pSamplerSpecificData, pMetadata->data.smpl.samplerSpecificDataSizeInBytes, &totalBytesRead);
|
||||
}
|
||||
} else {
|
||||
MA_DR_WAV_ZERO_OBJECT(&pMetadata->data.smpl);
|
||||
}
|
||||
}
|
||||
return totalBytesRead;
|
||||
|
|
@ -83149,19 +83181,13 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_
|
|||
newSample0 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8;
|
||||
newSample0 += nibble0 * pWav->msadpcm.delta[0];
|
||||
newSample0 = ma_dr_wav_clamp(newSample0, -32768, 32767);
|
||||
pWav->msadpcm.delta[0] = (adaptationTable[((nibbles & 0xF0) >> 4)] * pWav->msadpcm.delta[0]) >> 8;
|
||||
if (pWav->msadpcm.delta[0] < 16) {
|
||||
pWav->msadpcm.delta[0] = 16;
|
||||
}
|
||||
pWav->msadpcm.delta[0] = (ma_int32)ma_dr_wav_clamp(((ma_int64)adaptationTable[((nibbles & 0xF0) >> 4)] * pWav->msadpcm.delta[0]) >> 8, 16, 0x7FFFFFFF);
|
||||
pWav->msadpcm.prevFrames[0][0] = pWav->msadpcm.prevFrames[0][1];
|
||||
pWav->msadpcm.prevFrames[0][1] = newSample0;
|
||||
newSample1 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8;
|
||||
newSample1 += nibble1 * pWav->msadpcm.delta[0];
|
||||
newSample1 = ma_dr_wav_clamp(newSample1, -32768, 32767);
|
||||
pWav->msadpcm.delta[0] = (adaptationTable[((nibbles & 0x0F) >> 0)] * pWav->msadpcm.delta[0]) >> 8;
|
||||
if (pWav->msadpcm.delta[0] < 16) {
|
||||
pWav->msadpcm.delta[0] = 16;
|
||||
}
|
||||
pWav->msadpcm.delta[0] = (ma_int32)ma_dr_wav_clamp(((ma_int64)adaptationTable[((nibbles & 0x0F) >> 0)] * pWav->msadpcm.delta[0]) >> 8, 16, 0x7FFFFFFF);
|
||||
pWav->msadpcm.prevFrames[0][0] = pWav->msadpcm.prevFrames[0][1];
|
||||
pWav->msadpcm.prevFrames[0][1] = newSample1;
|
||||
pWav->msadpcm.cachedFrames[2] = newSample0;
|
||||
|
|
@ -83176,10 +83202,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_
|
|||
newSample0 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8;
|
||||
newSample0 += nibble0 * pWav->msadpcm.delta[0];
|
||||
newSample0 = ma_dr_wav_clamp(newSample0, -32768, 32767);
|
||||
pWav->msadpcm.delta[0] = (adaptationTable[((nibbles & 0xF0) >> 4)] * pWav->msadpcm.delta[0]) >> 8;
|
||||
if (pWav->msadpcm.delta[0] < 16) {
|
||||
pWav->msadpcm.delta[0] = 16;
|
||||
}
|
||||
pWav->msadpcm.delta[0] = (ma_int32)ma_dr_wav_clamp(((ma_int64)adaptationTable[((nibbles & 0xF0) >> 4)] * pWav->msadpcm.delta[0]) >> 8, 16, 0x7FFFFFFF);
|
||||
pWav->msadpcm.prevFrames[0][0] = pWav->msadpcm.prevFrames[0][1];
|
||||
pWav->msadpcm.prevFrames[0][1] = newSample0;
|
||||
if (pWav->msadpcm.predictor[1] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[1] >= ma_dr_wav_countof(coeff2Table)) {
|
||||
|
|
@ -83188,10 +83211,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_
|
|||
newSample1 = ((pWav->msadpcm.prevFrames[1][1] * coeff1Table[pWav->msadpcm.predictor[1]]) + (pWav->msadpcm.prevFrames[1][0] * coeff2Table[pWav->msadpcm.predictor[1]])) >> 8;
|
||||
newSample1 += nibble1 * pWav->msadpcm.delta[1];
|
||||
newSample1 = ma_dr_wav_clamp(newSample1, -32768, 32767);
|
||||
pWav->msadpcm.delta[1] = (adaptationTable[((nibbles & 0x0F) >> 0)] * pWav->msadpcm.delta[1]) >> 8;
|
||||
if (pWav->msadpcm.delta[1] < 16) {
|
||||
pWav->msadpcm.delta[1] = 16;
|
||||
}
|
||||
pWav->msadpcm.delta[1] = (ma_int32)ma_dr_wav_clamp(((ma_int64)adaptationTable[((nibbles & 0x0F) >> 0)] * pWav->msadpcm.delta[1]) >> 8, 16, 0x7FFFFFFF);
|
||||
pWav->msadpcm.prevFrames[1][0] = pWav->msadpcm.prevFrames[1][1];
|
||||
pWav->msadpcm.prevFrames[1][1] = newSample1;
|
||||
pWav->msadpcm.cachedFrames[2] = newSample0;
|
||||
|
|
@ -95825,7 +95845,7 @@ For more information, please refer to <http://unlicense.org/>
|
|||
===============================================================================
|
||||
ALTERNATIVE 2 - MIT No Attribution
|
||||
===============================================================================
|
||||
Copyright 2025 David Reid
|
||||
Copyright 2026 David Reid
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
this software and associated documentation files (the "Software"), to deal in
|
||||
|
|
|
|||
Loading…
Reference in New Issue