Merge branch 'master' into imatrix
This commit is contained in:
commit
090e07592e
|
|
@ -72,7 +72,7 @@ jobs:
|
|||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build -DLLAMA_BUILD_BORINGSSL=ON
|
||||
cmake -B build -DLLAMA_BUILD_BORINGSSL=ON -DGGML_SCHED_NO_REALLOC=ON
|
||||
cmake --build build --config ${{ matrix.build_type }} -j ${env:NUMBER_OF_PROCESSORS} --target llama-server
|
||||
|
||||
- name: Python setup
|
||||
|
|
@ -108,7 +108,7 @@ jobs:
|
|||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build -DLLAMA_BUILD_BORINGSSL=ON
|
||||
cmake -B build -DLLAMA_BUILD_BORINGSSL=ON -DGGML_SCHED_NO_REALLOC=ON
|
||||
cmake --build build --config Release -j ${env:NUMBER_OF_PROCESSORS} --target llama-server
|
||||
|
||||
- name: Python setup
|
||||
|
|
|
|||
|
|
@ -1630,7 +1630,7 @@ common_chat_msg common_chat_parse(const std::string & input, bool is_partial, co
|
|||
}
|
||||
auto msg = builder.result();
|
||||
if (!is_partial) {
|
||||
LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat<json>({msg}).at(0).dump().c_str());
|
||||
LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str());
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
|
@ -1663,7 +1663,7 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std
|
|||
mapper.from_ast(ctx.ast, result);
|
||||
}
|
||||
if (!is_partial) {
|
||||
LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat<json>({msg}).at(0).dump().c_str());
|
||||
LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str());
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
|
|
|||
220
common/chat.cpp
220
common/chat.cpp
|
|
@ -7,9 +7,6 @@
|
|||
#include "log.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
// #include <minja/chat-template.hpp>
|
||||
// #include <minja/minja.hpp>
|
||||
|
||||
#include "jinja/parser.h"
|
||||
#include "jinja/value.h"
|
||||
#include "jinja/runtime.h"
|
||||
|
|
@ -56,39 +53,73 @@ static bool has_content_or_tool_calls(const common_chat_msg & msg) {
|
|||
return !msg.content.empty() || !msg.tool_calls.empty();
|
||||
}
|
||||
|
||||
template <>
|
||||
json common_chat_msg::to_json_oaicompat() const
|
||||
{
|
||||
json message {
|
||||
{"role", "assistant"},
|
||||
};
|
||||
if (!reasoning_content.empty()) {
|
||||
message["reasoning_content"] = reasoning_content;
|
||||
json common_chat_msg::to_json_oaicompat(bool concat_typed_text) const {
|
||||
if (!content.empty() && !content_parts.empty()) {
|
||||
throw std::runtime_error("Cannot specify both content and content_parts");
|
||||
}
|
||||
if (content.empty() && !tool_calls.empty()) {
|
||||
message["content"] = json();
|
||||
json jmsg {
|
||||
{"role", role},
|
||||
};
|
||||
if (!content.empty()) {
|
||||
jmsg["content"] = content;
|
||||
} else if (!content_parts.empty()) {
|
||||
if (concat_typed_text) {
|
||||
std::string text;
|
||||
for (const auto & part : content_parts) {
|
||||
if (part.type != "text") {
|
||||
LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
|
||||
continue;
|
||||
}
|
||||
if (!text.empty()) {
|
||||
text += '\n';
|
||||
}
|
||||
text += part.text;
|
||||
}
|
||||
jmsg["content"] = text;
|
||||
} else {
|
||||
auto & parts = jmsg["content"] = json::array();
|
||||
for (const auto & part : content_parts) {
|
||||
parts.push_back({
|
||||
{"type", part.type},
|
||||
{"text", part.text},
|
||||
});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
message["content"] = content;
|
||||
jmsg["content"] = "";
|
||||
}
|
||||
if (!reasoning_content.empty()) {
|
||||
jmsg["reasoning_content"] = reasoning_content;
|
||||
}
|
||||
if (!tool_name.empty()) {
|
||||
jmsg["name"] = tool_name;
|
||||
}
|
||||
if (!tool_call_id.empty()) {
|
||||
jmsg["tool_call_id"] = tool_call_id;
|
||||
}
|
||||
if (!tool_calls.empty()) {
|
||||
auto arr = json::array();
|
||||
for (const auto & tc : tool_calls) {
|
||||
arr.push_back({
|
||||
jmsg["tool_calls"] = json::array();
|
||||
auto & jtool_calls = jmsg["tool_calls"];
|
||||
for (const auto & tool_call : tool_calls) {
|
||||
json tc {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", tc.name},
|
||||
{"arguments", tc.arguments},
|
||||
{"name", tool_call.name},
|
||||
{"arguments", tool_call.arguments},
|
||||
}},
|
||||
{"id", tc.id},
|
||||
// // Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo).
|
||||
// // We only generate a random id for the ones that don't generate one by themselves
|
||||
// // (they also won't get to see it as their template likely doesn't use it, so it's all for the client)
|
||||
// {"id", tc.id.empty() ? gen_tool_call_id() : tc.id},
|
||||
});
|
||||
};
|
||||
if (!tool_call.id.empty()) {
|
||||
tc["id"] = tool_call.id;
|
||||
}
|
||||
// Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo).
|
||||
// We only generate a random id for the ones that don't generate one by themselves
|
||||
// (they also won't get to see it as their template likely doesn't use it, so it's all for the client)
|
||||
// {"id", tc.id.empty() ? gen_tool_call_id() : tc.id},
|
||||
jtool_calls.push_back(tc);
|
||||
}
|
||||
message["tool_calls"] = arr;
|
||||
}
|
||||
return message;
|
||||
|
||||
return jmsg;
|
||||
}
|
||||
|
||||
std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new) {
|
||||
|
|
@ -256,7 +287,6 @@ bool common_chat_templates_support_enable_thinking(const common_chat_templates *
|
|||
return rendered_no_thinking.prompt != rendered_with_thinking.prompt;
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messages) {
|
||||
std::vector<common_chat_msg> msgs;
|
||||
|
||||
|
|
@ -350,80 +380,15 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
|
|||
return msgs;
|
||||
}
|
||||
|
||||
template <>
|
||||
json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
|
||||
json messages = json::array();
|
||||
for (const auto & msg : msgs) {
|
||||
if (!msg.content.empty() && !msg.content_parts.empty()) {
|
||||
throw std::runtime_error("Cannot specify both content and content_parts");
|
||||
}
|
||||
json jmsg {
|
||||
{"role", msg.role},
|
||||
};
|
||||
if (!msg.content.empty()) {
|
||||
jmsg["content"] = msg.content;
|
||||
} else if (!msg.content_parts.empty()) {
|
||||
if (concat_typed_text) {
|
||||
std::string text;
|
||||
for (const auto & part : msg.content_parts) {
|
||||
if (part.type != "text") {
|
||||
LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
|
||||
continue;
|
||||
}
|
||||
if (!text.empty()) {
|
||||
text += '\n';
|
||||
}
|
||||
text += part.text;
|
||||
}
|
||||
jmsg["content"] = text;
|
||||
} else {
|
||||
auto & parts = jmsg["content"] = json::array();
|
||||
for (const auto & part : msg.content_parts) {
|
||||
parts.push_back({
|
||||
{"type", part.type},
|
||||
{"text", part.text},
|
||||
});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
jmsg["content"] = "";
|
||||
}
|
||||
if (!msg.reasoning_content.empty()) {
|
||||
jmsg["reasoning_content"] = msg.reasoning_content;
|
||||
}
|
||||
if (!msg.tool_name.empty()) {
|
||||
jmsg["name"] = msg.tool_name;
|
||||
}
|
||||
if (!msg.tool_call_id.empty()) {
|
||||
jmsg["tool_call_id"] = msg.tool_call_id;
|
||||
}
|
||||
if (!msg.tool_calls.empty()) {
|
||||
auto & tool_calls = jmsg["tool_calls"] = json::array();
|
||||
for (const auto & tool_call : msg.tool_calls) {
|
||||
json tc {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", tool_call.name},
|
||||
{"arguments", tool_call.arguments},
|
||||
}},
|
||||
};
|
||||
if (!tool_call.id.empty()) {
|
||||
tc["id"] = tool_call.id;
|
||||
}
|
||||
tool_calls.push_back(tc);
|
||||
}
|
||||
}
|
||||
json jmsg = msg.to_json_oaicompat(concat_typed_text);
|
||||
messages.push_back(jmsg);
|
||||
}
|
||||
return messages;
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const std::string & messages) {
|
||||
return common_chat_msgs_parse_oaicompat(json::parse(messages));
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & tools) {
|
||||
std::vector<common_chat_tool> result;
|
||||
|
||||
|
|
@ -459,12 +424,6 @@ std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & too
|
|||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const std::string & tools) {
|
||||
return common_chat_tools_parse_oaicompat(json::parse(tools));
|
||||
}
|
||||
|
||||
template <>
|
||||
json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools) {
|
||||
if (tools.empty()) {
|
||||
return json();
|
||||
|
|
@ -484,7 +443,7 @@ json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & t
|
|||
return result;
|
||||
}
|
||||
|
||||
template <> json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) {
|
||||
json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) {
|
||||
json delta = json::object();
|
||||
if (!diff.reasoning_content_delta.empty()) {
|
||||
delta["reasoning_content"] = diff.reasoning_content_delta;
|
||||
|
|
@ -2691,6 +2650,45 @@ static common_chat_params common_chat_params_init_exaone_moe(const common_chat_t
|
|||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_translate_gemma(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
// This template does not support tools or reasoning
|
||||
// we just need to transform the messages into the correct schema
|
||||
|
||||
templates_params inputs_new = inputs;
|
||||
json & messages = inputs_new.messages;
|
||||
|
||||
GGML_ASSERT(messages.is_array());
|
||||
for (auto & message : messages) {
|
||||
if (message.contains("role") && message["role"].get<std::string>() != "user") {
|
||||
continue;
|
||||
}
|
||||
if (!message.contains("content")) {
|
||||
message["content"] = json::array();
|
||||
}
|
||||
if (message.contains("content") && !message["content"].is_array()) {
|
||||
auto content_str = message["content"].get<std::string>();
|
||||
// default to en-GB if not specified (to make common_chat_format_example works)
|
||||
auto src_lang = message.contains("source_lang_code") ? message["source_lang_code"].get<std::string>() : "en-GB";
|
||||
auto tgt_lang = message.contains("target_lang_code") ? message["target_lang_code"].get<std::string>() : "en-GB";
|
||||
message["content"] = json::array({
|
||||
json{
|
||||
{"type", "text"},
|
||||
{"text", content_str},
|
||||
{"source_lang_code", src_lang},
|
||||
{"target_lang_code", tgt_lang},
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
data.prompt = apply(tmpl, inputs_new, std::nullopt, std::nullopt);
|
||||
data.format = COMMON_CHAT_FORMAT_GENERIC;
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
data.prompt = apply(tmpl, inputs);
|
||||
|
|
@ -2867,13 +2865,13 @@ static common_chat_params common_chat_templates_apply_jinja(
|
|||
const struct common_chat_templates_inputs & inputs)
|
||||
{
|
||||
templates_params params;
|
||||
params.tools = common_chat_tools_to_json_oaicompat<json>(inputs.tools);
|
||||
params.tools = common_chat_tools_to_json_oaicompat(inputs.tools);
|
||||
const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
|
||||
? *tmpls->template_tool_use
|
||||
: *tmpls->template_default;
|
||||
const auto & src = tmpl.source();
|
||||
const auto & caps = tmpl.original_caps();
|
||||
params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
|
||||
params.messages = common_chat_msgs_to_json_oaicompat(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
|
||||
params.add_generation_prompt = inputs.add_generation_prompt;
|
||||
params.tool_choice = inputs.tool_choice;
|
||||
params.reasoning_format = inputs.reasoning_format;
|
||||
|
|
@ -2943,6 +2941,10 @@ static common_chat_params common_chat_templates_apply_jinja(
|
|||
src.find("<arg_value>") != std::string::npos &&
|
||||
params.json_schema.is_null()) {
|
||||
workaround::func_args_not_string(params.messages);
|
||||
if (!params.extra_context.contains("clear_thinking")) {
|
||||
// by default, do not clear reasoning_content (added since GLM-4.7)
|
||||
params.extra_context["clear_thinking"] = false;
|
||||
}
|
||||
return common_chat_params_init_glm_4_5(tmpl, params);
|
||||
}
|
||||
|
||||
|
|
@ -3082,6 +3084,12 @@ static common_chat_params common_chat_templates_apply_jinja(
|
|||
return common_chat_params_init_solar_open(tmpl, params);
|
||||
}
|
||||
|
||||
// TranslateGemma
|
||||
if (src.find("[source_lang_code]") != std::string::npos &&
|
||||
src.find("[target_lang_code]") != std::string::npos) {
|
||||
return common_chat_params_init_translate_gemma(tmpl, params);
|
||||
}
|
||||
|
||||
// Plain handler (no tools)
|
||||
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
return common_chat_params_init_without_tools(tmpl, params);
|
||||
|
|
@ -3174,3 +3182,9 @@ common_chat_params common_chat_templates_apply(
|
|||
? common_chat_templates_apply_jinja(tmpls, inputs)
|
||||
: common_chat_templates_apply_legacy(tmpls, inputs);
|
||||
}
|
||||
|
||||
std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_templates * chat_templates) {
|
||||
GGML_ASSERT(chat_templates != nullptr);
|
||||
GGML_ASSERT(chat_templates->template_default != nullptr);
|
||||
return chat_templates->template_default->caps.to_map();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@
|
|||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
struct common_chat_templates;
|
||||
|
||||
struct common_chat_tool_call {
|
||||
|
|
@ -26,6 +28,11 @@ struct common_chat_msg_content_part {
|
|||
std::string type;
|
||||
std::string text;
|
||||
|
||||
// TODO @ngxson : no known chat templates support reasoning_content in content parts yet
|
||||
// this can be useful for models with interleaved thinking (like Kimi-K2)
|
||||
// if you see any templates explicitly support this, please ping me
|
||||
// std::string reasoning_content;
|
||||
|
||||
bool operator==(const common_chat_msg_content_part & other) const {
|
||||
return type == other.type && text == other.text;
|
||||
}
|
||||
|
|
@ -40,7 +47,7 @@ struct common_chat_msg {
|
|||
std::string tool_name;
|
||||
std::string tool_call_id;
|
||||
|
||||
template <class T> T to_json_oaicompat() const;
|
||||
nlohmann::ordered_json to_json_oaicompat(bool concat_typed_text = false) const;
|
||||
|
||||
bool empty() const {
|
||||
return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
|
||||
|
|
@ -232,13 +239,13 @@ common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::strin
|
|||
bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates);
|
||||
|
||||
// Parses a JSON array of messages in OpenAI's chat completion API format.
|
||||
// T can be std::string containing JSON or nlohmann::ordered_json
|
||||
template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages);
|
||||
template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
|
||||
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const nlohmann::ordered_json & messages);
|
||||
nlohmann::ordered_json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
|
||||
|
||||
// Parses a JSON array of tools in OpenAI's chat completion tool call API format.
|
||||
// T can be std::string containing JSON or nlohmann::ordered_json
|
||||
template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
|
||||
template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
|
||||
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const nlohmann::ordered_json & tools);
|
||||
nlohmann::ordered_json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
|
||||
|
||||
template <class T> T common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
|
||||
nlohmann::ordered_json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
|
||||
|
||||
// get template caps, useful for reporting to server /props endpoint
|
||||
std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_templates * chat_templates);
|
||||
|
|
|
|||
|
|
@ -61,14 +61,23 @@ static void caps_print_stats(value & v, const std::string & path) {
|
|||
ops.c_str());
|
||||
}
|
||||
|
||||
std::map<std::string, bool> caps::to_map() const {
|
||||
return {
|
||||
{"requires_typed_content", requires_typed_content},
|
||||
{"supports_tools", supports_tools},
|
||||
{"supports_tool_calls", supports_tool_calls},
|
||||
{"supports_parallel_tool_calls", supports_parallel_tool_calls},
|
||||
{"supports_system_role", supports_system_role},
|
||||
{"supports_preserve_reasoning", supports_preserve_reasoning},
|
||||
};
|
||||
}
|
||||
|
||||
std::string caps::to_string() const {
|
||||
std::ostringstream ss;
|
||||
ss << "Caps(\n";
|
||||
ss << " requires_typed_content=" << requires_typed_content << "\n";
|
||||
ss << " supports_tools=" << supports_tools << "\n";
|
||||
ss << " supports_tool_calls=" << supports_tool_calls << "\n";
|
||||
ss << " supports_parallel_tool_calls=" << supports_parallel_tool_calls << "\n";
|
||||
ss << " supports_system_role=" << supports_system_role << "\n";
|
||||
for (const auto & [key, value] : to_map()) {
|
||||
ss << " " << key << "=" << (value ? "true" : "false") << "\n";
|
||||
}
|
||||
ss << ")";
|
||||
return ss.str();
|
||||
}
|
||||
|
|
@ -229,6 +238,40 @@ caps caps_get(jinja::program & prog) {
|
|||
}
|
||||
);
|
||||
|
||||
// case: preserve reasoning content in chat history
|
||||
caps_try_execute(
|
||||
prog,
|
||||
[&]() {
|
||||
// messages
|
||||
return json::array({
|
||||
{
|
||||
{"role", "user"},
|
||||
{"content", "User message"}
|
||||
},
|
||||
{
|
||||
{"role", "assistant"},
|
||||
{"content", "Assistant message"},
|
||||
{"reasoning_content", "Reasoning content"}
|
||||
},
|
||||
{
|
||||
{"role", "user"},
|
||||
{"content", "User message"}
|
||||
},
|
||||
});
|
||||
},
|
||||
[&]() {
|
||||
// tools
|
||||
return json::array();
|
||||
},
|
||||
[&](bool, value & messages, value &) {
|
||||
auto & content = messages->at(1)->at("reasoning_content");
|
||||
caps_print_stats(content, "messages[1].reasoning_content");
|
||||
if (content->stats.used) {
|
||||
result.supports_preserve_reasoning = true;
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
JJ_DEBUG("%s\n", result.to_string().c_str());
|
||||
|
||||
return result;
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
#include "runtime.h"
|
||||
|
||||
#include <string>
|
||||
#include <map>
|
||||
|
||||
namespace jinja {
|
||||
|
||||
|
|
@ -11,14 +12,17 @@ struct caps {
|
|||
bool supports_tool_calls = true;
|
||||
bool supports_system_role = true;
|
||||
bool supports_parallel_tool_calls = true;
|
||||
bool supports_preserve_reasoning = false; // support assistant message with reasoning_content
|
||||
|
||||
bool requires_typed_content = false; // default: use string content
|
||||
|
||||
// for reporting on server
|
||||
std::map<std::string, bool> to_map() const;
|
||||
|
||||
// for debugging
|
||||
std::string to_string() const;
|
||||
};
|
||||
|
||||
caps caps_get(jinja::program & prog);
|
||||
void debug_print_caps(const caps & c);
|
||||
|
||||
} // namespace jinja
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -3,6 +3,7 @@
|
|||
set -e
|
||||
|
||||
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
|
||||
BUILD_DIR="${2:-"$BUILD_DIR"}"
|
||||
|
||||
# Final check if we have a model path
|
||||
if [ -z "$CONVERTED_MODEL" ]; then
|
||||
|
|
@ -25,9 +26,13 @@ mkdir -p ppl
|
|||
OUTPUTFILE="ppl/$(basename $CONVERTED_MODEL).kld"
|
||||
echo "Model: $CONVERTED_MODEL"
|
||||
|
||||
cmake --build ../../build --target llama-perplexity -j8
|
||||
if [ -z "$BUILD_DIR" ]; then
|
||||
BUILD_DIR="../../build"
|
||||
fi
|
||||
|
||||
../.././build/bin/llama-perplexity -m $CONVERTED_MODEL \
|
||||
cmake --build $BUILD_DIR --target llama-perplexity -j8
|
||||
|
||||
${BUILD_DIR}/bin/llama-perplexity -m $CONVERTED_MODEL \
|
||||
-f ppl/wikitext-2-raw/wiki.test.raw \
|
||||
--kl-divergence-base $OUTPUTFILE
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
set -e
|
||||
|
||||
QUANTIZED_MODEL="${1:-"$QUANTIZED_MODEL"}"
|
||||
BUILD_DIR="${2:-"$BUILD_DIR"}"
|
||||
|
||||
if [ -z "$QUANTIZED_MODEL" ]; then
|
||||
echo "Error: Model path must be provided either as:" >&2
|
||||
|
|
@ -20,8 +21,12 @@ if [ ! -d "ppl/wikitext-2-raw" ]; then
|
|||
popd
|
||||
fi
|
||||
|
||||
cmake --build ../../build --target llama-perplexity -j8
|
||||
if [ -z "$BUILD_DIR" ]; then
|
||||
BUILD_DIR="../../build"
|
||||
fi
|
||||
|
||||
../.././build/bin/llama-perplexity -m $QUANTIZED_MODEL -f ppl/wikitext-2-raw/wiki.test.raw
|
||||
cmake --build $BUILD_DIR --target llama-perplexity -j8
|
||||
|
||||
${BUILD_DIR}/bin/llama-perplexity -m $QUANTIZED_MODEL -f ppl/wikitext-2-raw/wiki.test.raw
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@
|
|||
set -e
|
||||
|
||||
QUANTIZED_MODEL="${1:-"$QUANTIZED_MODEL"}"
|
||||
LOGITS_FILE="${1:-"$LOGITS_FILE"}"
|
||||
LOGITS_FILE="${2:-"$LOGITS_FILE"}"
|
||||
BUILD_DIR="${3:-"$BUILD_DIR"}"
|
||||
|
||||
if [ -z "$QUANTIZED_MODEL" ]; then
|
||||
echo "Error: Model path must be provided either as:" >&2
|
||||
|
|
@ -18,11 +19,15 @@ if [ ! -f ${LOGITS_FILE} ]; then
|
|||
exit 1
|
||||
fi
|
||||
|
||||
if [ -z "$BUILD_DIR" ]; then
|
||||
BUILD_DIR="../../build"
|
||||
fi
|
||||
|
||||
echo "Model: $QUANTIZED_MODEL"
|
||||
echo "Data file: $LOGITS_FILE"
|
||||
|
||||
cmake --build ../../build --target llama-perplexity -j8
|
||||
cmake --build $BUILD_DIR --target llama-perplexity -j8
|
||||
|
||||
../.././build/bin/llama-perplexity -m $QUANTIZED_MODEL \
|
||||
${BUILD_DIR}/bin/llama-perplexity -m $QUANTIZED_MODEL \
|
||||
--kl-divergence-base $LOGITS_FILE \
|
||||
--kl-divergence
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
|
|||
QUANTIZED_TYPE="${2:-"$QUANTIZED_TYPE"}"
|
||||
TOKEN_EMBD_TYPE="${3:-"${TOKEN_EMBD_TYPE}"}"
|
||||
OUTPUT_TYPE="${4:-"${OUTPUT_TYPE}"}"
|
||||
BUILD_DIR="${5:-"$BUILD_DIR"}"
|
||||
QUANTIZED_MODEL=$CONVERTED_MODEL
|
||||
|
||||
# Final check if we have a model path
|
||||
|
|
@ -33,12 +34,16 @@ else
|
|||
exit 1
|
||||
fi
|
||||
|
||||
cmake --build ../../build --target llama-quantize -j8
|
||||
if [ -z "$BUILD_DIR" ]; then
|
||||
BUILD_DIR="../../build"
|
||||
fi
|
||||
|
||||
cmake --build $BUILD_DIR --target llama-quantize -j8
|
||||
|
||||
echo $TOKEN_EMBD_TYPE
|
||||
echo $OUTPUT_TYPE
|
||||
|
||||
CMD_ARGS=("../../build/bin/llama-quantize")
|
||||
CMD_ARGS=("${BUILD_DIR}/bin/llama-quantize")
|
||||
[[ -n "$TOKEN_EMBD_TYPE" ]] && CMD_ARGS+=("--token-embedding-type" "$TOKEN_EMBD_TYPE")
|
||||
[[ -n "$OUTPUT_TYPE" ]] && CMD_ARGS+=("--output-tensor-type" "$OUTPUT_TYPE")
|
||||
CMD_ARGS+=("$CONVERTED_MODEL" "$QUANTIZED_MODEL" "$QUANTIZED_TYPE")
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ set -e
|
|||
#
|
||||
# First try command line argument, then environment variable, then file
|
||||
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
|
||||
BUILD_DIR="${2:-"$BUILD_DIR"}"
|
||||
|
||||
# Final check if we have a model path
|
||||
if [ -z "$CONVERTED_MODEL" ]; then
|
||||
|
|
@ -13,10 +14,14 @@ if [ -z "$CONVERTED_MODEL" ]; then
|
|||
exit 1
|
||||
fi
|
||||
|
||||
if [ -z "$BUILD_DIR" ]; then
|
||||
BUILD_DIR="../../build"
|
||||
fi
|
||||
|
||||
echo $CONVERTED_MODEL
|
||||
|
||||
cmake --build ../../build --target llama-server
|
||||
cmake --build $BUILD_DIR --target llama-server
|
||||
|
||||
../../build/bin/llama-server -m $CONVERTED_MODEL \
|
||||
${BUILD_DIR}/bin/llama-server -m $CONVERTED_MODEL \
|
||||
--embedding \
|
||||
--pooling none
|
||||
|
|
|
|||
|
|
@ -38,9 +38,10 @@
|
|||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
|
|
@ -48,9 +49,10 @@
|
|||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
|
|
@ -70,12 +72,14 @@
|
|||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
|
||||
|
|
@ -94,9 +98,10 @@
|
|||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
|
|
@ -104,9 +109,10 @@
|
|||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
|
|
@ -126,9 +132,10 @@
|
|||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
|
|
@ -136,9 +143,10 @@
|
|||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
|
|
@ -165,18 +173,20 @@
|
|||
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
|
|
@ -202,9 +212,10 @@
|
|||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
|
|
@ -212,9 +223,10 @@
|
|||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
|
|
@ -242,9 +254,10 @@
|
|||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
|
|
@ -252,9 +265,10 @@
|
|||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
|
|
|
|||
|
|
@ -25,9 +25,8 @@
|
|||
#define UNUSED GGML_UNUSED
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && (defined(__ARM_FEATURE_MATMUL_INT8) || defined(__ARM_FEATURE_DOTPROD))
|
||||
static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in,
|
||||
int16x8_t * out_mins,
|
||||
int8_t * out_scales) {
|
||||
// Helper for decoding scales and mins of Q4_K and Q5_K block formats
|
||||
static inline void decode_q_Kx8_6bit_scales(const uint8_t * scales_in, int16x8_t * out_mins, int8_t * out_scales) {
|
||||
constexpr uint32_t kmask1 = 0x3f3f3f3f;
|
||||
constexpr uint32_t kmask2 = 0x0f0f0f0f;
|
||||
constexpr uint32_t kmask3 = 0x03030303;
|
||||
|
|
@ -561,7 +560,7 @@ void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|||
for (int i = 0; i < 2; i++) {
|
||||
int8_t aux_q4sb[8];
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
|
||||
decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
|
||||
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
|
||||
}
|
||||
|
||||
|
|
@ -701,7 +700,7 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
|
|||
for (int i = 0; i < 2; i++) {
|
||||
int8_t aux_q4sb[8];
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
|
||||
decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
|
||||
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
|
||||
}
|
||||
|
||||
|
|
@ -786,6 +785,293 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
|
|||
ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q5_K_8x8_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
||||
constexpr int ncols_interleaved = 8;
|
||||
constexpr int blocklen = 8;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
constexpr int col_pairs = ncols_interleaved / 2;
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
const uint8x16_t mone = vdupq_n_u8(1);
|
||||
const uint8x16_t mtwo = vdupq_n_u8(2);
|
||||
|
||||
// 1x8 tile = 2 x 4
|
||||
float32x4_t acc_f32[ncols_interleaved / 4];
|
||||
|
||||
const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int i = 0; i < ncols_interleaved / 4; i++) {
|
||||
acc_f32[i] = vdupq_n_f32(0);
|
||||
}
|
||||
|
||||
for (int b = 0; b < nb; b++) {
|
||||
float32x4_t q5_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d)); // d0 d1 d2 d3
|
||||
float32x4_t q5_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4)); // d4 d5 d6 d7
|
||||
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
|
||||
float32x4_t sb_scale_0 = vmulq_f32(q5_d_0, q8_d);
|
||||
float32x4_t sb_scale_1 = vmulq_f32(q5_d_1, q8_d);
|
||||
float32x4_t q5_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin)); // dmin 0..3
|
||||
float32x4_t q5_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4)); // dmin 4..7
|
||||
float32x4_t sb_min_0 = vmulq_f32(q5_dmin_0, q8_d);
|
||||
float32x4_t sb_min_1 = vmulq_f32(q5_dmin_1, q8_d);
|
||||
|
||||
// 2 sb each iteration
|
||||
int32x4_t acc_lo[col_pairs];
|
||||
int32x4_t acc_hi[col_pairs];
|
||||
|
||||
// Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
|
||||
const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
|
||||
int16_t bsums_arr[8];
|
||||
vst1q_s16(bsums_arr, bsums);
|
||||
|
||||
// Load qh once per block and shift after each subblock
|
||||
const uint8_t * qh_base = q5_ptr[b].qh;
|
||||
uint8x16_t qh[col_pairs][4];
|
||||
for (int cp = 0; cp < col_pairs; cp++) {
|
||||
qh[cp][0] = vld1q_u8(qh_base + 16 * cp);
|
||||
qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64);
|
||||
qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128);
|
||||
qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192);
|
||||
}
|
||||
|
||||
for (int sb = 0; sb < QK_K / 64; sb++) {
|
||||
for (int i = 0; i < col_pairs; i++) {
|
||||
acc_lo[i] = vdupq_n_s32(0);
|
||||
acc_hi[i] = vdupq_n_s32(0);
|
||||
}
|
||||
// Need scales for the low and high nibbles
|
||||
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
||||
int16x8_t q5sb_mins[2]; // int16 as its needed for bias_acc later
|
||||
int16x8_t q5sb_scales[2];
|
||||
for (int i = 0; i < 2; i++) {
|
||||
int8_t aux_q5sb[8];
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
|
||||
q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
|
||||
}
|
||||
|
||||
const uint8_t * qs_base = q5_ptr[b].qs + sb * QK_K;
|
||||
|
||||
// Load the 64 quants from q8K duplicated to use vecdots with the interleaved columns
|
||||
const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
|
||||
int8x16_t q8_qs[8];
|
||||
for (int i = 0; i < 8; i++) {
|
||||
q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8));
|
||||
}
|
||||
|
||||
// Q5s column pair loop unrolled
|
||||
{
|
||||
// Cols 01
|
||||
uint8x16_t qs_0 = vld1q_u8(qs_base);
|
||||
uint8x16_t qs_1 = vld1q_u8(qs_base + 64);
|
||||
uint8x16_t qs_2 = vld1q_u8(qs_base + 128);
|
||||
uint8x16_t qs_3 = vld1q_u8(qs_base + 192);
|
||||
|
||||
uint8x16_t hbit_lo_0 = vandq_u8(qh[0][0], mone);
|
||||
uint8x16_t hbit_lo_1 = vandq_u8(qh[0][1], mone);
|
||||
uint8x16_t hbit_lo_2 = vandq_u8(qh[0][2], mone);
|
||||
uint8x16_t hbit_lo_3 = vandq_u8(qh[0][3], mone);
|
||||
uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[0][0], mtwo), 3);
|
||||
uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[0][1], mtwo), 3);
|
||||
uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[0][2], mtwo), 3);
|
||||
uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[0][3], mtwo), 3);
|
||||
|
||||
qh[0][0] = vshrq_n_u8(qh[0][0], 2);
|
||||
qh[0][1] = vshrq_n_u8(qh[0][1], 2);
|
||||
qh[0][2] = vshrq_n_u8(qh[0][2], 2);
|
||||
qh[0][3] = vshrq_n_u8(qh[0][3], 2);
|
||||
|
||||
acc_lo[0] = ggml_vdotq_s32(
|
||||
acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
|
||||
acc_lo[0] = ggml_vdotq_s32(
|
||||
acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
|
||||
acc_lo[0] = ggml_vdotq_s32(
|
||||
acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
|
||||
acc_lo[0] = ggml_vdotq_s32(
|
||||
acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
|
||||
acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
|
||||
q8_qs[4]);
|
||||
acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
|
||||
q8_qs[5]);
|
||||
acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
|
||||
q8_qs[6]);
|
||||
acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
|
||||
q8_qs[7]);
|
||||
|
||||
// Cols 23
|
||||
qs_0 = vld1q_u8(qs_base + 16);
|
||||
qs_1 = vld1q_u8(qs_base + 80);
|
||||
qs_2 = vld1q_u8(qs_base + 144);
|
||||
qs_3 = vld1q_u8(qs_base + 208);
|
||||
|
||||
hbit_lo_0 = vandq_u8(qh[1][0], mone);
|
||||
hbit_lo_1 = vandq_u8(qh[1][1], mone);
|
||||
hbit_lo_2 = vandq_u8(qh[1][2], mone);
|
||||
hbit_lo_3 = vandq_u8(qh[1][3], mone);
|
||||
hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[1][0], mtwo), 3);
|
||||
hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[1][1], mtwo), 3);
|
||||
hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[1][2], mtwo), 3);
|
||||
hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[1][3], mtwo), 3);
|
||||
|
||||
qh[1][0] = vshrq_n_u8(qh[1][0], 2);
|
||||
qh[1][1] = vshrq_n_u8(qh[1][1], 2);
|
||||
qh[1][2] = vshrq_n_u8(qh[1][2], 2);
|
||||
qh[1][3] = vshrq_n_u8(qh[1][3], 2);
|
||||
|
||||
acc_lo[1] = ggml_vdotq_s32(
|
||||
acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
|
||||
acc_lo[1] = ggml_vdotq_s32(
|
||||
acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
|
||||
acc_lo[1] = ggml_vdotq_s32(
|
||||
acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
|
||||
acc_lo[1] = ggml_vdotq_s32(
|
||||
acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
|
||||
acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
|
||||
q8_qs[4]);
|
||||
acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
|
||||
q8_qs[5]);
|
||||
acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
|
||||
q8_qs[6]);
|
||||
acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
|
||||
q8_qs[7]);
|
||||
|
||||
// Cols 45
|
||||
qs_0 = vld1q_u8(qs_base + 32);
|
||||
qs_1 = vld1q_u8(qs_base + 96);
|
||||
qs_2 = vld1q_u8(qs_base + 160);
|
||||
qs_3 = vld1q_u8(qs_base + 224);
|
||||
|
||||
hbit_lo_0 = vandq_u8(qh[2][0], mone);
|
||||
hbit_lo_1 = vandq_u8(qh[2][1], mone);
|
||||
hbit_lo_2 = vandq_u8(qh[2][2], mone);
|
||||
hbit_lo_3 = vandq_u8(qh[2][3], mone);
|
||||
hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[2][0], mtwo), 3);
|
||||
hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[2][1], mtwo), 3);
|
||||
hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[2][2], mtwo), 3);
|
||||
hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[2][3], mtwo), 3);
|
||||
|
||||
qh[2][0] = vshrq_n_u8(qh[2][0], 2);
|
||||
qh[2][1] = vshrq_n_u8(qh[2][1], 2);
|
||||
qh[2][2] = vshrq_n_u8(qh[2][2], 2);
|
||||
qh[2][3] = vshrq_n_u8(qh[2][3], 2);
|
||||
|
||||
acc_lo[2] = ggml_vdotq_s32(
|
||||
acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
|
||||
acc_lo[2] = ggml_vdotq_s32(
|
||||
acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
|
||||
acc_lo[2] = ggml_vdotq_s32(
|
||||
acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
|
||||
acc_lo[2] = ggml_vdotq_s32(
|
||||
acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
|
||||
acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
|
||||
q8_qs[4]);
|
||||
acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
|
||||
q8_qs[5]);
|
||||
acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
|
||||
q8_qs[6]);
|
||||
acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
|
||||
q8_qs[7]);
|
||||
|
||||
// Cols 45
|
||||
qs_0 = vld1q_u8(qs_base + 48);
|
||||
qs_1 = vld1q_u8(qs_base + 112);
|
||||
qs_2 = vld1q_u8(qs_base + 176);
|
||||
qs_3 = vld1q_u8(qs_base + 240);
|
||||
|
||||
hbit_lo_0 = vandq_u8(qh[3][0], mone);
|
||||
hbit_lo_1 = vandq_u8(qh[3][1], mone);
|
||||
hbit_lo_2 = vandq_u8(qh[3][2], mone);
|
||||
hbit_lo_3 = vandq_u8(qh[3][3], mone);
|
||||
hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[3][0], mtwo), 3);
|
||||
hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[3][1], mtwo), 3);
|
||||
hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[3][2], mtwo), 3);
|
||||
hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[3][3], mtwo), 3);
|
||||
|
||||
qh[3][0] = vshrq_n_u8(qh[3][0], 2);
|
||||
qh[3][1] = vshrq_n_u8(qh[3][1], 2);
|
||||
qh[3][2] = vshrq_n_u8(qh[3][2], 2);
|
||||
qh[3][3] = vshrq_n_u8(qh[3][3], 2);
|
||||
|
||||
acc_lo[3] = ggml_vdotq_s32(
|
||||
acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
|
||||
acc_lo[3] = ggml_vdotq_s32(
|
||||
acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
|
||||
acc_lo[3] = ggml_vdotq_s32(
|
||||
acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
|
||||
acc_lo[3] = ggml_vdotq_s32(
|
||||
acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
|
||||
acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
|
||||
q8_qs[4]);
|
||||
acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
|
||||
q8_qs[5]);
|
||||
acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
|
||||
q8_qs[6]);
|
||||
acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
|
||||
q8_qs[7]);
|
||||
}
|
||||
|
||||
// Prepare bsum vectors for bias computation
|
||||
// Each pair of subblocks share the same bsums
|
||||
int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
|
||||
int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
|
||||
|
||||
// Iterates over a pair of column pairs (4 columns) to use a single 128 register
|
||||
// p = 0 -> 0123 p2 -> 4567
|
||||
for (int i = 0, p = 0; p < col_pairs; i++, p += 2) {
|
||||
int16x4_t group_scales_lo = p == 0 ? vget_low_s16(q5sb_scales[0]) : vget_high_s16(q5sb_scales[0]);
|
||||
int16x4_t group_scales_hi = p == 0 ? vget_low_s16(q5sb_scales[1]) : vget_high_s16(q5sb_scales[1]);
|
||||
int16x4_t group_mins_lo = p == 0 ? vget_low_s16(q5sb_mins[0]) : vget_high_s16(q5sb_mins[0]);
|
||||
int16x4_t group_mins_hi = p == 0 ? vget_low_s16(q5sb_mins[1]) : vget_high_s16(q5sb_mins[1]);
|
||||
float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1;
|
||||
float32x4_t sb_min = p == 0 ? sb_min_0 : sb_min_1;
|
||||
|
||||
// 0123 or 4567
|
||||
float32x4_t sumf_0 =
|
||||
vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
|
||||
acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);
|
||||
|
||||
float32x4_t sumf_1 =
|
||||
vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1])));
|
||||
acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1);
|
||||
|
||||
// FUSED BIAS: Compute and subtract bias immediately
|
||||
// bias = (bsums_lo * mins_lo + bsums_hi * mins_hi) * sb_min
|
||||
int32x4_t bias = vmull_s16(bsums_vec_lo, group_mins_lo);
|
||||
bias = vmlal_s16(bias, bsums_vec_hi, group_mins_hi);
|
||||
float32x4_t bias_f32 = vcvtq_f32_s32(bias);
|
||||
acc_f32[i] = vmlsq_f32(acc_f32[i], sb_min, bias_f32);
|
||||
}
|
||||
} // for sb
|
||||
} // for b
|
||||
|
||||
int base = x * ncols_interleaved;
|
||||
vst1q_f32(s + base, acc_f32[0]);
|
||||
vst1q_f32(s + base + 4, acc_f32[1]);
|
||||
} // for x
|
||||
return;
|
||||
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q8_0_4x4_q8_0(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
|
|
@ -2431,7 +2717,7 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|||
for (int i = 0; i < 2; i++) {
|
||||
int8_t aux_q4sb[8];
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
|
||||
decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
|
||||
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
|
||||
}
|
||||
|
||||
|
|
@ -2595,7 +2881,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
|
|||
int16x8_t q4sb_mins[2]; // int16 as its needed for bias_acc later
|
||||
for (int i = 0; i < 2; i++) {
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
|
||||
decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
|
||||
}
|
||||
|
||||
// q8_ptr[b].qs has interleaved Q8 rows (01, 23)
|
||||
|
|
@ -2738,6 +3024,252 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
|
|||
ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q5_K_8x8_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
||||
constexpr int ncols_interleaved = 8;
|
||||
constexpr int blocklen = 8;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
constexpr int q8_k_blocklen = 4;
|
||||
constexpr int col_pairs = ncols_interleaved / 2;
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
const uint8x16_t mone = vdupq_n_u8(1);
|
||||
const uint8x16_t mtwo = vdupq_n_u8(2);
|
||||
|
||||
// 8 accumulators: 2 row pairs × 4 col pairs
|
||||
float32x4_t acc_f32[blocklen];
|
||||
|
||||
for (int y = 0; y < nr / q8_k_blocklen; y++) {
|
||||
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int i = 0; i < blocklen; i++) {
|
||||
acc_f32[i] = vdupq_n_f32(0);
|
||||
}
|
||||
|
||||
for (int b = 0; b < nb; b++) {
|
||||
// bsums pairs belongs to the same q8_k subblock
|
||||
const int16x8_t bsums[4]{
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
|
||||
};
|
||||
int16_t bsums_arr[4][8];
|
||||
for (int q8_row = 0; q8_row < 4; q8_row++) {
|
||||
vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
|
||||
}
|
||||
|
||||
int32x4_t sb_acc[4]; // Aux accumulators to store subblock (partial) results
|
||||
int32x4_t acc[8]; // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]
|
||||
int32x4_t bias_acc[8]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
|
||||
for (int i = 0; i < 8; i++) {
|
||||
acc[i] = vdupq_n_s32(0);
|
||||
bias_acc[i] = vdupq_n_s32(0);
|
||||
}
|
||||
|
||||
// Load qh once per block and shift after each subblock
|
||||
const uint8_t * qh_base = q5_ptr[b].qh;
|
||||
uint8x16_t qh[col_pairs][4];
|
||||
for (int cp = 0; cp < col_pairs; cp++) {
|
||||
qh[cp][0] = vld1q_u8(qh_base + 16 * cp);
|
||||
qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64);
|
||||
qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128);
|
||||
qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192);
|
||||
}
|
||||
|
||||
for (int sb = 0; sb < QK_K / 64; sb++) {
|
||||
// Need scales for the low and high nibbles
|
||||
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
||||
int8_t q5sb_scales[2][8];
|
||||
int16x8_t q5sb_mins[2]; // int16 as its needed for bias_acc later
|
||||
for (int i = 0; i < 2; i++) {
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], q5sb_scales[i]);
|
||||
}
|
||||
|
||||
// q8_ptr[b].qs has interleaved Q8 rows (01, 23)
|
||||
const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
|
||||
|
||||
int8x16_t q8_qs_01[8];
|
||||
int8x16_t q8_qs_23[8];
|
||||
|
||||
// Load 32-byte per row pair, 1 subblock each time
|
||||
for (int i = 0; i < 8; i++) {
|
||||
const int offset = i * 32; // 16 for row 01, 16 for row 23
|
||||
q8_qs_01[i] = vld1q_s8(q8_base + offset);
|
||||
q8_qs_23[i] = vld1q_s8(q8_base + offset + 16);
|
||||
}
|
||||
|
||||
const int8x16_t q8s[2][8] = {
|
||||
{ q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3], q8_qs_01[4], q8_qs_01[5], q8_qs_01[6],
|
||||
q8_qs_01[7] },
|
||||
{ q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3], q8_qs_23[4], q8_qs_23[5], q8_qs_23[6],
|
||||
q8_qs_23[7] },
|
||||
};
|
||||
|
||||
// Q5s columns iterated in pairs (01, 23, 45, 67)
|
||||
for (int cp = 0; cp < col_pairs; cp++) {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
sb_acc[i] = vdupq_n_s32(0);
|
||||
}
|
||||
|
||||
uint8x16_t qs_cp_0 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 7 & 32..39
|
||||
uint8x16_t qs_cp_1 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47
|
||||
uint8x16_t qs_cp_2 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55
|
||||
uint8x16_t qs_cp_3 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63
|
||||
|
||||
// This is the only part of the algorithm that differs with Q4_K
|
||||
// Extract High bits and pack into 5 bit weights
|
||||
uint8x16_t hbit_lo_0 = vandq_u8(qh[cp][0], mone);
|
||||
uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[cp][0], mtwo), 3);
|
||||
qh[cp][0] = vshrq_n_u8(qh[cp][0], 2);
|
||||
// Same as Q4_K, i8mm to dequantize the weights.
|
||||
const int8x16_t qs_lo_0 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_0, m4b), hbit_lo_0, 4));
|
||||
int32x4_t acc_0 = sb_acc[0];
|
||||
acc_0 = vmmlaq_s32(acc_0, qs_lo_0, q8s[0][0]);
|
||||
int32x4_t acc_2 = sb_acc[2];
|
||||
acc_2 = vmmlaq_s32(acc_2, qs_lo_0, q8s[1][0]);
|
||||
const int8x16_t qs_hi_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_0, 4), hbit_hi_0));
|
||||
int32x4_t acc_1 = sb_acc[1];
|
||||
acc_1 = vmmlaq_s32(acc_1, qs_hi_0, q8s[0][4]);
|
||||
int32x4_t acc_3 = sb_acc[3];
|
||||
acc_3 = vmmlaq_s32(acc_3, qs_hi_0, q8s[1][4]);
|
||||
|
||||
// Repeat for the other 3 columns (8..15, 16..23, 24..31)
|
||||
uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[cp][1], mtwo), 3);
|
||||
uint8x16_t hbit_lo_1 = vandq_u8(qh[cp][1], mone);
|
||||
qh[cp][1] = vshrq_n_u8(qh[cp][1], 2);
|
||||
const int8x16_t qs_lo_1 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_1, m4b), hbit_lo_1, 4));
|
||||
acc_0 = vmmlaq_s32(acc_0, qs_lo_1, q8s[0][1]);
|
||||
acc_2 = vmmlaq_s32(acc_2, qs_lo_1, q8s[1][1]);
|
||||
const int8x16_t qs_hi_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_1, 4), hbit_hi_1));
|
||||
acc_1 = vmmlaq_s32(acc_1, qs_hi_1, q8s[0][5]);
|
||||
acc_3 = vmmlaq_s32(acc_3, qs_hi_1, q8s[1][5]);
|
||||
|
||||
uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[cp][2], mtwo), 3);
|
||||
uint8x16_t hbit_lo_2 = vandq_u8(qh[cp][2], mone);
|
||||
qh[cp][2] = vshrq_n_u8(qh[cp][2], 2);
|
||||
const int8x16_t qs_lo_2 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_2, m4b), hbit_lo_2, 4));
|
||||
acc_0 = vmmlaq_s32(acc_0, qs_lo_2, q8s[0][2]);
|
||||
acc_2 = vmmlaq_s32(acc_2, qs_lo_2, q8s[1][2]);
|
||||
const int8x16_t qs_hi_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_2, 4), hbit_hi_2));
|
||||
acc_1 = vmmlaq_s32(acc_1, qs_hi_2, q8s[0][6]);
|
||||
acc_3 = vmmlaq_s32(acc_3, qs_hi_2, q8s[1][6]);
|
||||
|
||||
uint8x16_t hbit_lo_3 = vandq_u8(qh[cp][3], mone);
|
||||
uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[cp][3], mtwo), 3);
|
||||
qh[cp][3] = vshrq_n_u8(qh[cp][3], 2);
|
||||
const int8x16_t qs_lo_3 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_3, m4b), hbit_lo_3, 4));
|
||||
acc_0 = vmmlaq_s32(acc_0, qs_lo_3, q8s[0][3]);
|
||||
sb_acc[0] = acc_0;
|
||||
acc_2 = vmmlaq_s32(acc_2, qs_lo_3, q8s[1][3]);
|
||||
sb_acc[2] = acc_2;
|
||||
|
||||
// Scales[i] corresponds to column i
|
||||
const int scale_offset = cp * 2;
|
||||
const int32_t s0 = q5sb_scales[0][scale_offset];
|
||||
const int32_t s1 = q5sb_scales[0][scale_offset + 1];
|
||||
const int32x4_t block_scale = vcombine_s32(vdup_n_s32(s0), vdup_n_s32(s1));
|
||||
acc[cp] = vmlaq_s32(acc[cp], sb_acc[0], block_scale);
|
||||
acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale);
|
||||
|
||||
const int8x16_t qs_hi_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_3, 4), hbit_hi_3));
|
||||
acc_1 = vmmlaq_s32(acc_1, qs_hi_3, q8s[0][7]);
|
||||
sb_acc[1] = acc_1;
|
||||
acc_3 = vmmlaq_s32(acc_3, qs_hi_3, q8s[1][7]);
|
||||
sb_acc[3] = acc_3;
|
||||
|
||||
const int32_t s2 = q5sb_scales[1][scale_offset];
|
||||
const int32_t s3 = q5sb_scales[1][scale_offset + 1];
|
||||
const int32x4_t block_scale2 = vcombine_s32(vdup_n_s32(s2), vdup_n_s32(s3));
|
||||
acc[cp] = vmlaq_s32(acc[cp], sb_acc[1], block_scale2);
|
||||
acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale2);
|
||||
}
|
||||
|
||||
// Multiply Acc bsum + mins
|
||||
for (int q8_row = 0; q8_row < 4; q8_row++) {
|
||||
// Each pair of subblocks share the same bsums
|
||||
// Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)).
|
||||
int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
|
||||
int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
|
||||
|
||||
bias_acc[2 * q8_row] =
|
||||
vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
|
||||
bias_acc[2 * q8_row] =
|
||||
vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
|
||||
bias_acc[2 * q8_row + 1] =
|
||||
vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
|
||||
bias_acc[2 * q8_row + 1] =
|
||||
vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
|
||||
}
|
||||
} // for sb
|
||||
|
||||
// Reorder of i8mm output with bias and output layout
|
||||
for (int i = 0; i < 8; i++) {
|
||||
int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
|
||||
acc[i] = vcombine_s32(aux.val[0], aux.val[1]);
|
||||
}
|
||||
int32x4_t reorder_acc[8] = {
|
||||
vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
|
||||
vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
|
||||
vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
|
||||
vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
|
||||
vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
|
||||
vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
|
||||
vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
|
||||
vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
|
||||
};
|
||||
|
||||
for (int i = 0; i < q8_k_blocklen; i++) {
|
||||
for (int j = 0; j < 2; j++) {
|
||||
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]);
|
||||
float32x4_t q5_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].dmin + j * 4)));
|
||||
const float32x4_t dmins = vmulq_f32(q5_dmin, q8_d);
|
||||
|
||||
float32x4_t q5_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].d + j * 4)));
|
||||
const float32x4_t scale = vmulq_f32(q5_d, q8_d);
|
||||
|
||||
acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
|
||||
acc_f32[2 * i + j] =
|
||||
vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
|
||||
}
|
||||
}
|
||||
} // for b
|
||||
|
||||
// With the previous reorder, the tile is already in the correct memory layout.
|
||||
for (int i = 0; i < q8_k_blocklen; i++) {
|
||||
int row = y * q8_k_blocklen + i;
|
||||
for (int j = 0; j < 2; j++) {
|
||||
int col = x * ncols_interleaved + j * 4;
|
||||
int offset = row * bs + col;
|
||||
vst1q_f32(s + offset, acc_f32[2 * i + j]);
|
||||
}
|
||||
}
|
||||
} // for x
|
||||
} // for y
|
||||
return;
|
||||
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q8_0_4x4_q8_0(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
|
|
|
|||
|
|
@ -474,15 +474,8 @@ void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
|||
assert (n % qk == 0);
|
||||
assert (nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(s);
|
||||
UNUSED(bs);
|
||||
UNUSED(vx);
|
||||
UNUSED(vy);
|
||||
UNUSED(nr);
|
||||
UNUSED(nc);
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
float sumf[8];
|
||||
float sum_minf[8];
|
||||
|
|
@ -616,6 +609,100 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
|||
}
|
||||
}
|
||||
|
||||
void ggml_gemv_q5_K_8x8_q8_K_generic(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 8;
|
||||
const int blocklen = 8;
|
||||
static const uint32_t kmask1 = 0x3f3f3f3f;
|
||||
static const uint32_t kmask2 = 0x0f0f0f0f;
|
||||
static const uint32_t kmask3 = 0x03030303;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(bs);
|
||||
UNUSED(nr);
|
||||
|
||||
float sumf[8];
|
||||
float sum_minf[8];
|
||||
uint32_t utmp[32];
|
||||
int sumi1;
|
||||
int sumi2;
|
||||
int sumi;
|
||||
|
||||
const block_q8_K * a_ptr = (const block_q8_K *) vy;
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[j] = 0.0;
|
||||
sum_minf[j] = 0.0;
|
||||
}
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
|
||||
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
|
||||
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
|
||||
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
|
||||
utmp[sb * 4 + 2] = uaux_0;
|
||||
utmp[sb * 4 + 0] &= kmask1;
|
||||
}
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32;
|
||||
uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16;
|
||||
|
||||
const int qh_shift = (k / 4) * 2;
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi1 = 0;
|
||||
sumi2 = 0;
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
|
||||
|
||||
const int qh_idx = (k * 8 + i) % 32;
|
||||
const int qh_chunk = qh_idx / 8;
|
||||
const int qh_pos = qh_idx % 8;
|
||||
const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos;
|
||||
|
||||
const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
|
||||
const uint8_t h0 = (qh_val >> qh_shift) & 1;
|
||||
const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
|
||||
|
||||
const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
|
||||
const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
|
||||
|
||||
const int q8_offset = (k >> 2) * 64 + (k % 4) * blocklen + i;
|
||||
|
||||
sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
|
||||
sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]);
|
||||
sumi1 = sumi1 * scales_0[j];
|
||||
sumi2 = sumi2 * scales_1[j];
|
||||
sumi += sumi1 + sumi2;
|
||||
}
|
||||
sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
|
||||
}
|
||||
}
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) *
|
||||
GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
|
|
@ -1212,6 +1299,108 @@ void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
|||
}
|
||||
}
|
||||
|
||||
void ggml_gemm_q5_K_8x8_q8_K_generic(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 8;
|
||||
const int blocklen = 8;
|
||||
|
||||
constexpr uint32_t kmask1 = 0x3f3f3f3f;
|
||||
constexpr uint32_t kmask2 = 0x0f0f0f0f;
|
||||
constexpr uint32_t kmask3 = 0x03030303;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
float sumf[4][8];
|
||||
float sum_minf[4][8];
|
||||
uint32_t utmp[32];
|
||||
int sumi1;
|
||||
int sumi2;
|
||||
int sumi;
|
||||
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[m][j] = 0.0;
|
||||
sum_minf[m][j] = 0.0;
|
||||
}
|
||||
}
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
|
||||
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
|
||||
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
|
||||
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
|
||||
utmp[sb * 4 + 2] = uaux_0;
|
||||
utmp[sb * 4 + 0] &= kmask1;
|
||||
}
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32;
|
||||
uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16;
|
||||
|
||||
const int qh_shift = (k / 4) * 2;
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi1 = 0;
|
||||
sumi2 = 0;
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
|
||||
|
||||
const int qh_idx = (k * 8 + i) % 32;
|
||||
const int qh_chunk = qh_idx / 8;
|
||||
const int qh_pos = qh_idx % 8;
|
||||
const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos;
|
||||
|
||||
const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
|
||||
const uint8_t h0 = (qh_val >> qh_shift) & 1;
|
||||
const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
|
||||
|
||||
const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
|
||||
const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
|
||||
|
||||
const int q8_offset = (k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i;
|
||||
|
||||
sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
|
||||
sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]);
|
||||
sumi1 = sumi1 * scales_0[j];
|
||||
sumi2 = sumi2 * scales_1[j];
|
||||
sumi += sumi1 + sumi2;
|
||||
}
|
||||
sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
|
||||
for (int m = 0; m < 4; m++) {
|
||||
const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) *
|
||||
GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
|
|
@ -1622,7 +1811,95 @@ static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_in
|
|||
out.scales[i] = in[src1].scales[src2];
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_interleave) {
|
||||
block_q5_Kx8 out;
|
||||
//Delta(scale) and dmin values of the eight Q5_K structures are copied onto the output interleaved structure
|
||||
for (int i = 0; i < 8; i++) {
|
||||
out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;
|
||||
}
|
||||
|
||||
for (int i = 0; i < 8; i++) {
|
||||
out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;
|
||||
}
|
||||
|
||||
const int end = QK_K * 4 / blck_size_interleave;
|
||||
|
||||
// Interleave Q5_K quants by taking 8 bytes at a time
|
||||
for (int i = 0; i < end; ++i) {
|
||||
int src_id = i % 8;
|
||||
int src_offset = (i / 8) * blck_size_interleave;
|
||||
int dst_offset = i * blck_size_interleave;
|
||||
|
||||
uint64_t elems;
|
||||
memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
|
||||
memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
|
||||
}
|
||||
|
||||
// Repeat for low bits 8 bytes at a time as well, since
|
||||
// the high bits are interleaved in Q5_K and the index is
|
||||
// qh_idx = (qs_idx % 32);
|
||||
// qh_val = qh[qh_idx] >> (qs_idx / 32);
|
||||
for (int i = 0; i < end / 4; ++i) {
|
||||
int src_id = i % 8;
|
||||
int src_offset = (i / 8) * blck_size_interleave;
|
||||
int dst_offset = i * blck_size_interleave;
|
||||
|
||||
uint64_t elems;
|
||||
memcpy(&elems, &in[src_id].qh[src_offset], sizeof(uint64_t));
|
||||
memcpy(&out.qh[dst_offset], &elems, sizeof(uint64_t));
|
||||
}
|
||||
|
||||
// The below logic is copied over from Q4_K
|
||||
// The point is to unpack all the scales and mins for each sub block every time we load 12 bytes.
|
||||
// Currently the Q5_K structure has 8 scales and 8 mins packed in 12 bytes ( 6 bits for each value)
|
||||
// The output Q5_Kx8 structure has 96 bytes
|
||||
// Every 12 byte is packed such that it contains scales and mins for corresponding sub blocks from Q5_K structure
|
||||
// For eg - First 12 bytes contains 8 scales and 8 mins - each of first sub block from different Q5_K structures
|
||||
uint8_t s[8], m[8];
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
for (int j = 0; j < 8; j++) {
|
||||
s[j] = in[j].scales[i] & 63;
|
||||
m[j] = in[j].scales[i + 4] & 63;
|
||||
}
|
||||
|
||||
out.scales[i * 12] = (s[0] & 63) + ((s[4] & 48) << 2);
|
||||
out.scales[i * 12 + 1] = (s[1] & 63) + ((s[5] & 48) << 2);
|
||||
out.scales[i * 12 + 2] = (s[2] & 63) + ((s[6] & 48) << 2);
|
||||
out.scales[i * 12 + 3] = (s[3] & 63) + ((s[7] & 48) << 2);
|
||||
out.scales[i * 12 + 4] = (m[0] & 63) + ((m[4] & 48) << 2);
|
||||
out.scales[i * 12 + 5] = (m[1] & 63) + ((m[5] & 48) << 2);
|
||||
out.scales[i * 12 + 6] = (m[2] & 63) + ((m[6] & 48) << 2);
|
||||
out.scales[i * 12 + 7] = (m[3] & 63) + ((m[7] & 48) << 2);
|
||||
out.scales[i * 12 + 8] = (s[4] & 15) + ((m[4] & 15) << 4);
|
||||
out.scales[i * 12 + 9] = (s[5] & 15) + ((m[5] & 15) << 4);
|
||||
out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4);
|
||||
out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4);
|
||||
}
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
for (int j = 0; j < 8; j++) {
|
||||
s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i + 8] & 15);
|
||||
m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i + 8] & 240) >> 4);
|
||||
}
|
||||
|
||||
out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2);
|
||||
out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2);
|
||||
out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2);
|
||||
out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2);
|
||||
out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2);
|
||||
out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2);
|
||||
out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2);
|
||||
out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2);
|
||||
out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4);
|
||||
out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4);
|
||||
out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4);
|
||||
out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4);
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
||||
|
|
@ -1718,6 +1995,38 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block
|
|||
GGML_UNUSED(data_size);
|
||||
}
|
||||
|
||||
static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t,
|
||||
int interleave_block,
|
||||
const void * GGML_RESTRICT data,
|
||||
size_t data_size) {
|
||||
GGML_ASSERT(t->type == GGML_TYPE_Q5_K);
|
||||
GGML_ASSERT(interleave_block == 8);
|
||||
constexpr int nrows_interleaved = 8;
|
||||
|
||||
block_q5_Kx8 * dst = (block_q5_Kx8 *) t->data;
|
||||
const block_q5_K * src = (const block_q5_K *) data;
|
||||
block_q5_K dst_tmp[8];
|
||||
int nrow = ggml_nrows(t);
|
||||
int nblocks = t->ne[0] / QK_K;
|
||||
|
||||
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_K));
|
||||
|
||||
if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
for (int b = 0; b < nrow; b += nrows_interleaved) {
|
||||
for (int64_t x = 0; x < nblocks; x++) {
|
||||
for (int i = 0; i < nrows_interleaved; i++) {
|
||||
dst_tmp[i] = src[x + i * nblocks];
|
||||
}
|
||||
*dst++ = make_block_q5_Kx8(dst_tmp, interleave_block);
|
||||
}
|
||||
src += nrows_interleaved * nblocks;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
||||
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
|
||||
GGML_ASSERT(interleave_block == 8);
|
||||
|
|
@ -1936,6 +2245,10 @@ template <> int repack<block_q2_K, 8, 8>(struct ggml_tensor * t, const void * da
|
|||
return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size);
|
||||
}
|
||||
|
||||
template <> int repack<block_q5_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
||||
return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size);
|
||||
}
|
||||
|
||||
template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
||||
return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
|
||||
}
|
||||
|
|
@ -1973,6 +2286,10 @@ template <> void gemv<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t
|
|||
ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
|
@ -1981,8 +2298,8 @@ template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
|
|||
ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
template <> void gemv<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
|
|
@ -2013,20 +2330,24 @@ template <> void gemm<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t
|
|||
ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
template <> void gemm<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
|
|
@ -2432,6 +2753,9 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
|||
static const ggml::cpu::repack::tensor_traits<block_q4_K, 4, 8, GGML_TYPE_Q8_K> q4_K_8x4_q8_K;
|
||||
static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
|
||||
|
||||
// instance for Q5_K
|
||||
static const ggml::cpu::repack::tensor_traits<block_q5_K, 8, 8, GGML_TYPE_Q8_K> q5_K_8x8_q8_K;
|
||||
|
||||
// instance for Q2
|
||||
static const ggml::cpu::repack::tensor_traits<block_q2_K, 8, 8, GGML_TYPE_Q8_K> q2_K_8x8_q8_K;
|
||||
|
||||
|
|
@ -2482,6 +2806,12 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
|||
return &q2_K_8x8_q8_K;
|
||||
}
|
||||
}
|
||||
} else if (cur->type == GGML_TYPE_Q5_K) {
|
||||
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
return &q5_K_8x8_q8_K;
|
||||
}
|
||||
}
|
||||
} else if (cur->type == GGML_TYPE_IQ4_NL) {
|
||||
if (ggml_cpu_has_avx2()) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ struct block_q4_Kx8 {
|
|||
};
|
||||
|
||||
static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
|
||||
|
||||
struct block_q2_Kx8 {
|
||||
ggml_half d[8]; // super-block scale for quantized scales
|
||||
ggml_half dmin[8]; // super-block scale for quantized mins
|
||||
|
|
@ -52,6 +53,18 @@ struct block_q2_Kx8 {
|
|||
};
|
||||
|
||||
static_assert(sizeof(block_q2_Kx8) == sizeof(ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding");
|
||||
|
||||
struct block_q5_Kx8 {
|
||||
ggml_half d[8]; // super-block scale for quantized scales
|
||||
ggml_half dmin[8]; // super-block scale for quantized mins
|
||||
uint8_t scales[96]; // scales and mins, quantized with 6 bits
|
||||
uint8_t qh[QK_K * 8 / 8]; // high bits of 5-bit quants
|
||||
uint8_t qs[QK_K * 8 / 2]; // low bits of 5-bit quants (in groups of 4)
|
||||
};
|
||||
|
||||
static_assert(sizeof(block_q5_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 5,
|
||||
"wrong q5_K block size/padding");
|
||||
|
||||
struct block_q8_Kx4 {
|
||||
float d[4]; // delta
|
||||
int8_t qs[QK_K * 4]; // quants
|
||||
|
|
@ -82,20 +95,22 @@ void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTR
|
|||
void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
|
|
@ -111,17 +126,19 @@ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GG
|
|||
void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
|
|
|
|||
|
|
@ -1327,10 +1327,44 @@ struct ggml_backend_cuda_context {
|
|||
cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
|
||||
cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
|
||||
|
||||
std::unique_ptr<ggml_cuda_graph> cuda_graph;
|
||||
|
||||
int curr_stream_no = 0;
|
||||
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
// Map from first_node_ptr to cuda_graph - allows multiple graphs per context
|
||||
// when the computation is split across CPU/GPU (e.g., with --n-cpu-moe)
|
||||
std::unordered_map<const void *, std::unique_ptr<ggml_cuda_graph>> cuda_graphs;
|
||||
|
||||
ggml_cuda_graph * cuda_graph(const void * first_node_ptr) {
|
||||
auto it = cuda_graphs.find(first_node_ptr);
|
||||
if (it == cuda_graphs.end()) {
|
||||
cuda_graphs[first_node_ptr] = std::make_unique<ggml_cuda_graph>();
|
||||
return cuda_graphs[first_node_ptr].get();
|
||||
}
|
||||
return it->second.get();
|
||||
}
|
||||
|
||||
// Check if any CUDA graph is enabled for this context (used by kernels that need to know
|
||||
// if graphs are in use without having access to the specific graph key)
|
||||
bool any_cuda_graph_enabled() const {
|
||||
for (const auto & [key, graph] : cuda_graphs) {
|
||||
if (graph && graph->is_enabled()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if any CUDA graph has an instance for this context
|
||||
bool any_cuda_graph_has_instance() const {
|
||||
for (const auto & [key, graph] : cuda_graphs) {
|
||||
if (graph && graph->instance != nullptr) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
#endif // USE_CUDA_GRAPH
|
||||
|
||||
explicit ggml_backend_cuda_context(int device) :
|
||||
device(device),
|
||||
name(GGML_CUDA_NAME + std::to_string(device)) {
|
||||
|
|
|
|||
|
|
@ -778,13 +778,11 @@ void launch_fattn(
|
|||
) {
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
|
||||
const bool is_mla = DV == 512; // TODO better parameterization
|
||||
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
GGML_ASSERT(V || is_mla);
|
||||
const bool V_is_K_view = V->op == GGML_OP_VIEW && V->src[0] == K && V->data == K->data;
|
||||
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
const ggml_tensor * sinks = dst->src[4];
|
||||
|
|
@ -794,9 +792,9 @@ void launch_fattn(
|
|||
GGML_ASSERT(Q->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT( Q->nb[0] == ggml_element_size(Q));
|
||||
GGML_ASSERT( K->nb[0] == ggml_element_size(K));
|
||||
GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
|
||||
GGML_ASSERT(Q->nb[0] == ggml_element_size(Q));
|
||||
GGML_ASSERT(K->nb[0] == ggml_element_size(K));
|
||||
GGML_ASSERT(V->nb[0] == ggml_element_size(V));
|
||||
|
||||
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
|
||||
|
||||
|
|
@ -817,10 +815,10 @@ void launch_fattn(
|
|||
size_t nb12 = K->nb[2];
|
||||
size_t nb13 = K->nb[3];
|
||||
|
||||
const char * V_data = V ? (const char *) V->data : nullptr;
|
||||
size_t nb21 = V ? V->nb[1] : nb11;
|
||||
size_t nb22 = V ? V->nb[2] : nb12;
|
||||
size_t nb23 = V ? V->nb[3] : nb13;
|
||||
const char * V_data = (const char *) V->data;
|
||||
size_t nb21 = V->nb[1];
|
||||
size_t nb22 = V->nb[2];
|
||||
size_t nb23 = V->nb[3];
|
||||
|
||||
if (need_f16_K && K->type != GGML_TYPE_F16) {
|
||||
const size_t bs = ggml_blck_size(K->type);
|
||||
|
|
@ -849,32 +847,39 @@ void launch_fattn(
|
|||
K_data = (char *) K_f16.ptr;
|
||||
}
|
||||
|
||||
if (V && need_f16_V && V->type != GGML_TYPE_F16) {
|
||||
const size_t bs = ggml_blck_size(V->type);
|
||||
const size_t ts = ggml_type_size(V->type);
|
||||
|
||||
V_f16.alloc(ggml_nelements(V));
|
||||
if (ggml_is_contiguously_allocated(V)) {
|
||||
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
|
||||
to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
|
||||
V_data = (char *) V_f16.ptr;
|
||||
|
||||
nb21 = nb21*bs*sizeof(half)/ts;
|
||||
nb22 = nb22*bs*sizeof(half)/ts;
|
||||
nb23 = nb23*bs*sizeof(half)/ts;
|
||||
if (need_f16_V && V->type != GGML_TYPE_F16) {
|
||||
if (V_is_K_view) {
|
||||
V_data = K_data;
|
||||
nb21 = nb11;
|
||||
nb22 = nb12;
|
||||
nb23 = nb13;
|
||||
} else {
|
||||
GGML_ASSERT(V->nb[0] == ts);
|
||||
to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
|
||||
const int64_t s01 = nb21 / ts;
|
||||
const int64_t s02 = nb22 / ts;
|
||||
const int64_t s03 = nb23 / ts;
|
||||
to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
|
||||
const size_t bs = ggml_blck_size(V->type);
|
||||
const size_t ts = ggml_type_size(V->type);
|
||||
|
||||
nb21 = V->ne[0] * sizeof(half);
|
||||
nb22 = V->ne[1] * nb21;
|
||||
nb23 = V->ne[2] * nb22;
|
||||
V_f16.alloc(ggml_nelements(V));
|
||||
if (ggml_is_contiguously_allocated(V)) {
|
||||
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
|
||||
to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
|
||||
V_data = (char *) V_f16.ptr;
|
||||
|
||||
nb21 = nb21*bs*sizeof(half)/ts;
|
||||
nb22 = nb22*bs*sizeof(half)/ts;
|
||||
nb23 = nb23*bs*sizeof(half)/ts;
|
||||
} else {
|
||||
GGML_ASSERT(V->nb[0] == ts);
|
||||
to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
|
||||
const int64_t s01 = nb21 / ts;
|
||||
const int64_t s02 = nb22 / ts;
|
||||
const int64_t s03 = nb23 / ts;
|
||||
to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
|
||||
|
||||
nb21 = V->ne[0] * sizeof(half);
|
||||
nb22 = V->ne[1] * nb21;
|
||||
nb23 = V->ne[2] * nb22;
|
||||
}
|
||||
V_data = (char *) V_f16.ptr;
|
||||
}
|
||||
V_data = (char *) V_f16.ptr;
|
||||
}
|
||||
|
||||
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
|
||||
|
|
|
|||
|
|
@ -400,7 +400,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
|||
}
|
||||
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps,
|
||||
bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter, bool oob_check,
|
||||
bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup, bool last_iter, bool oob_check,
|
||||
typename T_A_KQ, typename T_B_KQ, typename T_C_KQ, typename T_A_VKQ, typename T_B_VKQ, typename T_C_VKQ>
|
||||
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
const float2 * const __restrict__ Q_f2,
|
||||
|
|
@ -442,8 +442,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||
constexpr int stride_tile_Q = DKQ/2 + 4;
|
||||
constexpr int stride_tile_K = nbatch_K2 + 4;
|
||||
|
||||
static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
|
||||
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
|
||||
constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;
|
||||
|
||||
const int k_VKQ_0 = kb0 * nbatch_fa;
|
||||
#if defined(TURING_MMA_AVAILABLE)
|
||||
|
|
@ -456,7 +455,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||
|
||||
if constexpr (nstages > 1) {
|
||||
static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline");
|
||||
static_assert(!mla, "multi-stage loading not implemented for MLA");
|
||||
static_assert(!V_is_K_view, "K data reuse not implemented multi-stage loading");
|
||||
static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
|
||||
constexpr bool use_cp_async = true;
|
||||
cp_async_wait_all();
|
||||
|
|
@ -471,8 +470,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||
}
|
||||
}
|
||||
|
||||
// For MLA K and V have the same data.
|
||||
// Therefore, iterate over K in reverse and later re-use the data if possible.
|
||||
#pragma unroll
|
||||
for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) {
|
||||
for (int k0_start = (DKQ/2-1) - (DKQ/2-1) % nbatch_K2; k0_start >= 0; k0_start -= nbatch_K2) {
|
||||
const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
|
||||
const int k0_diff = k0_stop - k0_start;
|
||||
|
||||
|
|
@ -776,6 +777,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||
}
|
||||
|
||||
if constexpr (nstages > 1) {
|
||||
static_assert(!V_is_K_view, "K data reuse not implemented multi-stage loading");
|
||||
// Preload K tile for next iteration:
|
||||
constexpr bool use_cp_async = true;
|
||||
cp_async_wait_all();
|
||||
|
|
@ -791,10 +793,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||
}
|
||||
|
||||
|
||||
// For MLA K and V have the same data.
|
||||
// Therefore, iterate over V in reverse and re-use the data if possible.
|
||||
static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
|
||||
constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
|
||||
#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
|
||||
T_A_VKQ A_identity;
|
||||
make_identity_mat(A_identity);
|
||||
|
|
@ -802,12 +800,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||
|
||||
// Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V:
|
||||
#pragma unroll
|
||||
for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
|
||||
const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
|
||||
const int i0_diff = i0_stop - i0_start;
|
||||
for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) {
|
||||
static_assert(DV % (2*nbatch_V2) == 0, "bad loop size");
|
||||
const int i0_stop = i0_start + 2*nbatch_V2;
|
||||
const int i0_diff = i0_stop - i0_start;
|
||||
|
||||
if constexpr (nstages <= 1) {
|
||||
if (i0_start < reusable_cutoff) {
|
||||
if (!V_is_K_view || i0_stop > 2*nbatch_K2) {
|
||||
constexpr bool use_cp_async = nstages == 1;
|
||||
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check>
|
||||
(V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup);
|
||||
|
|
@ -817,7 +816,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||
__syncthreads();
|
||||
}
|
||||
}
|
||||
const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
|
||||
const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2;
|
||||
|
||||
#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J;
|
||||
|
|
@ -920,7 +919,7 @@ template<int ncols> struct mma_tile_sizes {
|
|||
};
|
||||
#endif // defined(TURING_MMA_AVAILABLE)
|
||||
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup>
|
||||
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const float2 * const __restrict__ Q_f2,
|
||||
const half2 * const __restrict__ K_h2,
|
||||
|
|
@ -974,8 +973,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|||
constexpr int stride_tile_Q = DKQ/2 + 4;
|
||||
constexpr int stride_tile_K = nbatch_K2 + 4;
|
||||
|
||||
static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
|
||||
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
|
||||
constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;
|
||||
constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
|
||||
|
||||
extern __shared__ half2 tile_Q[];
|
||||
|
|
@ -1079,7 +1077,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|||
constexpr bool last_iter = false;
|
||||
constexpr int k_VKQ_sup = nbatch_fa;
|
||||
flash_attn_ext_f16_iter
|
||||
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
|
||||
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
|
||||
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
|
||||
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
|
||||
|
|
@ -1088,7 +1086,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|||
constexpr bool last_iter = true;
|
||||
const int k_VKQ_sup = ne11 - kb0*nbatch_fa;
|
||||
flash_attn_ext_f16_iter
|
||||
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
|
||||
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
|
||||
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
|
||||
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
|
||||
|
|
@ -1099,7 +1097,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|||
constexpr bool last_iter = false;
|
||||
constexpr int k_VKQ_sup = nbatch_fa;
|
||||
flash_attn_ext_f16_iter
|
||||
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
|
||||
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
|
||||
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
|
||||
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
|
||||
|
|
@ -1108,7 +1106,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|||
constexpr bool last_iter = true;
|
||||
constexpr int k_VKQ_sup = nbatch_fa;
|
||||
flash_attn_ext_f16_iter
|
||||
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
|
||||
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
|
||||
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
|
||||
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
|
||||
|
|
@ -1456,7 +1454,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|||
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
|
||||
}
|
||||
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool mla>
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool V_is_K_view>
|
||||
__launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2))
|
||||
static __global__ void flash_attn_ext_f16(
|
||||
const char * __restrict__ Q,
|
||||
|
|
@ -1508,8 +1506,6 @@ static __global__ void flash_attn_ext_f16(
|
|||
}
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
|
||||
static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
|
||||
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
|
||||
constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols);
|
||||
|
|
@ -1522,7 +1518,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
const int stride_K = nb11 / sizeof(half2);
|
||||
const int stride_mask = nb31 / sizeof(half);
|
||||
|
||||
const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
|
||||
const int stride_V = V_is_K_view ? stride_K : nb21 / sizeof(half2);
|
||||
|
||||
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
|
||||
const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
|
||||
|
|
@ -1552,7 +1548,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
(const half *) (mask + nb33*(sequence % ne33));
|
||||
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
|
||||
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
|
||||
const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
|
||||
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
|
||||
|
|
@ -1563,12 +1559,12 @@ static __global__ void flash_attn_ext_f16(
|
|||
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
||||
if (kb0_start == 0) {
|
||||
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
|
||||
} else {
|
||||
constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
|
||||
}
|
||||
|
|
@ -1596,7 +1592,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
(const half *) (mask + nb33*(sequence % ne33));
|
||||
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
|
||||
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
|
||||
const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
|
||||
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
|
||||
|
|
@ -1607,7 +1603,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
|
||||
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
||||
constexpr bool needs_fixup = false;
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
|
||||
#else
|
||||
|
|
@ -1643,7 +1639,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
|||
const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc));
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
|
||||
constexpr bool mla = DKQ == 576;
|
||||
constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu
|
||||
|
||||
const size_t nbytes_shared_KV_1stage = nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_KV_2stage = nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
|
||||
|
|
@ -1668,7 +1664,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
|||
fattn_kernel_t fattn_kernel;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>;
|
||||
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>;
|
||||
|
||||
#if !defined(GGML_USE_MUSA)
|
||||
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
||||
|
|
@ -1679,7 +1675,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
|||
#endif // !defined(GGML_USE_MUSA)
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>;
|
||||
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>;
|
||||
|
||||
#if !defined(GGML_USE_MUSA)
|
||||
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
|
|||
// are put into the template specialization without GQA optimizations.
|
||||
bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
for (const ggml_tensor * t : {Q, K, V, mask}) {
|
||||
if (t == nullptr) {
|
||||
if (t == nullptr || ggml_is_quantized(t->type)) {
|
||||
continue;
|
||||
}
|
||||
for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
|
||||
|
|
@ -236,7 +236,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||
// The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded,
|
||||
bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
for (const ggml_tensor * t : {Q, K, V, mask}) {
|
||||
if (t == nullptr) {
|
||||
if (t == nullptr || ggml_is_quantized(t->type)) {
|
||||
continue;
|
||||
}
|
||||
for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
|
||||
|
|
@ -247,6 +247,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||
}
|
||||
}
|
||||
|
||||
const bool V_is_K_view = V->op == GGML_OP_VIEW && V->src[0] == K && V->data == K->data;
|
||||
|
||||
const int cc = ggml_cuda_info().devices[device].cc;
|
||||
|
||||
switch (K->ne[0]) {
|
||||
|
|
@ -269,6 +271,9 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||
if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
if (!V_is_K_view) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
|
|
|
|||
|
|
@ -2969,18 +2969,25 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
|
|||
return true;
|
||||
}
|
||||
|
||||
static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) {
|
||||
return cgraph->nodes[0];
|
||||
}
|
||||
|
||||
static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
|
||||
|
||||
bool res = false;
|
||||
|
||||
if (cuda_ctx->cuda_graph->instance == nullptr) {
|
||||
const void * graph_key = ggml_cuda_graph_get_key(cgraph);
|
||||
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
|
||||
|
||||
if (graph->instance == nullptr) {
|
||||
res = true;
|
||||
}
|
||||
|
||||
// Check if the graph size has changed
|
||||
if (cuda_ctx->cuda_graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
|
||||
if (graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
|
||||
res = true;
|
||||
cuda_ctx->cuda_graph->props.resize(cgraph->n_nodes + cgraph->n_leafs);
|
||||
graph->props.resize(cgraph->n_nodes + cgraph->n_leafs);
|
||||
}
|
||||
|
||||
// Loop over nodes in GGML graph to determine if CUDA graph update is required
|
||||
|
|
@ -2988,37 +2995,38 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
|
|||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
bool props_match = true;
|
||||
if (!res) {
|
||||
props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &cuda_ctx->cuda_graph->props[i]);
|
||||
props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]);
|
||||
}
|
||||
if (!props_match) {
|
||||
res = true;
|
||||
}
|
||||
ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[i], cgraph->nodes[i]);
|
||||
ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]);
|
||||
}
|
||||
|
||||
for (int i = 0; i < cgraph->n_leafs; i++) {
|
||||
bool props_match= true;
|
||||
bool props_match = true;
|
||||
if (!res) {
|
||||
props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &cuda_ctx->cuda_graph->props[cgraph->n_nodes + i]);
|
||||
props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &graph->props[cgraph->n_nodes + i]);
|
||||
}
|
||||
if (!props_match) {
|
||||
res = true;
|
||||
}
|
||||
ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[cgraph->n_nodes + i], cgraph->leafs[i]);
|
||||
ggml_cuda_graph_node_set_properties(&graph->props[cgraph->n_nodes + i], cgraph->leafs[i]);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx) {
|
||||
static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {
|
||||
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
|
||||
|
||||
#if CUDART_VERSION >= 12000
|
||||
cudaGraphExecUpdateResultInfo result_info;
|
||||
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
|
||||
cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &result_info);
|
||||
#else
|
||||
cudaGraphNode_t errorNode;
|
||||
cudaGraphExecUpdateResult result_info;
|
||||
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
|
||||
cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &errorNode, &result_info);
|
||||
#endif // CUDART_VERSION >= 12000
|
||||
|
||||
if (stat == cudaErrorGraphExecUpdateFailure) {
|
||||
|
|
@ -3029,14 +3037,14 @@ static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_c
|
|||
// The pre-existing graph exec cannot be updated due to violated constraints
|
||||
// so instead clear error and re-instantiate
|
||||
(void)cudaGetLastError();
|
||||
CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
|
||||
cuda_ctx->cuda_graph->instance = nullptr;
|
||||
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
|
||||
CUDA_CHECK(cudaGraphExecDestroy(graph->instance));
|
||||
graph->instance = nullptr;
|
||||
CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0));
|
||||
} else {
|
||||
GGML_ASSERT(stat == cudaSuccess);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // USE_CUDA_GRAPH
|
||||
|
||||
static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
|
||||
const ggml_tensor * view,
|
||||
|
|
@ -3241,7 +3249,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
|||
return false;
|
||||
}
|
||||
|
||||
static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required) {
|
||||
static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) {
|
||||
bool graph_evaluated_or_captured = false;
|
||||
|
||||
// flag used to determine whether it is an integrated_gpu
|
||||
|
|
@ -3695,13 +3703,14 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
|
|||
}
|
||||
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
|
||||
if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
|
||||
if (cuda_ctx->cuda_graph->graph != nullptr) {
|
||||
CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
|
||||
cuda_ctx->cuda_graph->graph = nullptr;
|
||||
if (graph->graph != nullptr) {
|
||||
CUDA_CHECK(cudaGraphDestroy(graph->graph));
|
||||
graph->graph = nullptr;
|
||||
}
|
||||
|
||||
CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
|
||||
CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &graph->graph));
|
||||
graph_evaluated_or_captured = true; // CUDA graph has been captured
|
||||
|
||||
std::lock_guard<std::mutex> lock(ggml_cuda_lock);
|
||||
|
|
@ -3714,40 +3723,39 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
|
|||
}
|
||||
|
||||
if (use_cuda_graph) {
|
||||
if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
|
||||
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
|
||||
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
|
||||
if (graph->instance == nullptr) { // Create executable graph from captured graph.
|
||||
CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0));
|
||||
}
|
||||
if (cuda_graph_update_required) { // Update graph executable
|
||||
ggml_cuda_graph_update_executable(cuda_ctx);
|
||||
ggml_cuda_graph_update_executable(cuda_ctx, graph_key);
|
||||
}
|
||||
// Launch graph
|
||||
CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
|
||||
CUDA_CHECK(cudaGraphLaunch(graph->instance, cuda_ctx->stream()));
|
||||
#else
|
||||
graph_evaluated_or_captured = true;
|
||||
#endif // USE_CUDA_GRAPH
|
||||
}
|
||||
}
|
||||
|
||||
static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx) {
|
||||
static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {
|
||||
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
|
||||
|
||||
if (cuda_ctx->cuda_graph == nullptr) {
|
||||
cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
|
||||
}
|
||||
|
||||
if (cuda_ctx->cuda_graph->graph == nullptr) {
|
||||
if (graph->graph == nullptr) {
|
||||
if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
|
||||
if (!cuda_ctx->cuda_graph->disable_due_to_gpu_arch) {
|
||||
if (!graph->disable_due_to_gpu_arch) {
|
||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
|
||||
}
|
||||
cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
|
||||
graph->disable_due_to_gpu_arch = true;
|
||||
}
|
||||
}
|
||||
|
||||
return cuda_ctx->cuda_graph->is_enabled();
|
||||
return graph->is_enabled();
|
||||
#else
|
||||
GGML_UNUSED(cuda_ctx);
|
||||
GGML_UNUSED(graph_key);
|
||||
return false;
|
||||
#endif // USE_CUDA_GRAPH
|
||||
}
|
||||
|
|
@ -3759,15 +3767,19 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
|||
|
||||
bool use_cuda_graph = false;
|
||||
bool cuda_graph_update_required = false;
|
||||
const void * graph_key = nullptr;
|
||||
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
|
||||
graph_key = ggml_cuda_graph_get_key(cgraph);
|
||||
|
||||
if (cuda_ctx->cuda_graph->is_enabled()) {
|
||||
use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
|
||||
|
||||
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
|
||||
if (graph->is_enabled()) {
|
||||
cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
|
||||
use_cuda_graph = ggml_cuda_graph_check_compability(cgraph);
|
||||
|
||||
cuda_ctx->cuda_graph->record_update(use_cuda_graph, cuda_graph_update_required);
|
||||
graph->record_update(use_cuda_graph, cuda_graph_update_required);
|
||||
}
|
||||
#endif // USE_CUDA_GRAPH
|
||||
|
||||
|
|
@ -3781,7 +3793,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
|||
CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
|
||||
}
|
||||
|
||||
ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required);
|
||||
ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required, graph_key);
|
||||
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
|
|
@ -3814,7 +3826,14 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
|
|||
static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
|
||||
|
||||
const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
const void * graph_key = ggml_cuda_graph_get_key(cgraph);
|
||||
const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
|
||||
#else
|
||||
const bool use_cuda_graph = false;
|
||||
GGML_UNUSED(cuda_ctx);
|
||||
GGML_UNUSED(cgraph);
|
||||
#endif
|
||||
|
||||
static bool enable_graph_optimization = [] {
|
||||
const char * env = getenv("GGML_CUDA_GRAPH_OPT");
|
||||
|
|
|
|||
|
|
@ -31,14 +31,15 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||
#endif // USE_CUDA_GRAPH
|
||||
if ((nrows == 1) &&
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
// CUDA_GRAPHS_DISABLED
|
||||
((ncols > 65536) &&
|
||||
((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
|
||||
ctx.cuda_graph->is_enabled())) ||
|
||||
// CUDA_GRAPHS ENABLED
|
||||
((ncols > 32768) &&
|
||||
!((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
|
||||
ctx.cuda_graph->is_enabled()))) {
|
||||
// Determine if CUDA graphs are effectively disabled for this context
|
||||
// (no graph instance exists and we're not capturing, OR graphs are explicitly enabled)
|
||||
(((ncols > 65536) &&
|
||||
(((!ctx.any_cuda_graph_has_instance()) && (iscapturing == cudaStreamCaptureStatusNone)) ||
|
||||
ctx.any_cuda_graph_enabled())) ||
|
||||
// CUDA graphs are enabled - use lower threshold
|
||||
((ncols > 32768) &&
|
||||
!(((!ctx.any_cuda_graph_has_instance()) && (iscapturing == cudaStreamCaptureStatusNone)) ||
|
||||
ctx.any_cuda_graph_enabled())))) {
|
||||
#else
|
||||
(ncols > 65536)) {
|
||||
#endif // USE_CUDA_GRAPH
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#include <assert.h>
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_perf.h>
|
||||
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
|
||||
|
|
@ -111,7 +111,7 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
|
|||
hvx_vec_store_u(r, 4, rsum);
|
||||
}
|
||||
|
||||
// MAD: y (F32) += x (F16) * v (float)
|
||||
// MAD: y (F32) += x (F16) * s (float)
|
||||
static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
|
||||
const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
|
||||
HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
|
||||
|
|
@ -318,9 +318,12 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
|||
uint32_t ic = 0;
|
||||
|
||||
// Process in blocks of 32 (VLEN_FP32)
|
||||
for (; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32) {
|
||||
static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 == 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage");
|
||||
HVX_Vector_x4 scores_x4;
|
||||
HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY);
|
||||
for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
|
||||
// 1. Compute scores
|
||||
float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
|
||||
float __attribute__((aligned(VLEN))) scores_arr[FLASH_ATTN_BLOCK_SIZE];
|
||||
for (int j = 0; j < VLEN_FP32; ++j) {
|
||||
const uint32_t cur_ic = ic + j;
|
||||
const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
|
||||
|
|
@ -356,36 +359,43 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
|||
scores = Q6_Vsf_equals_Vqf32(scores);
|
||||
}
|
||||
|
||||
// 4. Online Softmax Update
|
||||
HVX_Vector v_max = hvx_vec_reduce_max_f32(scores);
|
||||
float m_block = hvx_vec_get_f32(v_max);
|
||||
scores_x4.v[iv] = scores;
|
||||
v_max = Q6_Vsf_vmax_VsfVsf(scores, v_max);
|
||||
}
|
||||
|
||||
{
|
||||
// 4. Online Softmax Update
|
||||
v_max = hvx_vec_reduce_max_f32(v_max);
|
||||
float m_block = hvx_vec_get_f32(v_max);
|
||||
float M_old = M;
|
||||
float M_new = (m_block > M) ? m_block : M;
|
||||
M = M_new;
|
||||
|
||||
float ms = expf(M_old - M_new);
|
||||
|
||||
const float ms = expf(M_old - M_new);
|
||||
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
|
||||
S = S * ms;
|
||||
|
||||
HVX_Vector M_new_vec = hvx_vec_splat_f32(M_new);
|
||||
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
|
||||
HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
|
||||
HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f);
|
||||
for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) {
|
||||
HVX_Vector scores = scores_x4.v[iv];
|
||||
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
|
||||
HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
|
||||
|
||||
HVX_Vector p_sum_vec = hvx_vec_reduce_sum_f32(P);
|
||||
float p_sum = hvx_vec_get_f32(p_sum_vec);
|
||||
S += p_sum;
|
||||
p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P));
|
||||
|
||||
// 5. Accumulate V
|
||||
float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
|
||||
*(HVX_Vector*)p_arr = P;
|
||||
// 5. Accumulate V
|
||||
float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
|
||||
*(HVX_Vector*)p_arr = P;
|
||||
|
||||
for (int j = 0; j < VLEN_FP32; ++j) {
|
||||
const uint32_t cur_ic = ic + j;
|
||||
const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
|
||||
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
|
||||
for (int j = 0; j < VLEN_FP32; ++j) {
|
||||
const uint32_t cur_ic = ic2 + j;
|
||||
const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
|
||||
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
|
||||
}
|
||||
}
|
||||
|
||||
p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec);
|
||||
S = S * ms + hvx_vec_get_f32(p_sum_vec);
|
||||
}
|
||||
|
||||
// Leftover
|
||||
|
|
|
|||
|
|
@ -398,6 +398,7 @@ struct ggml_backend_opencl_context {
|
|||
int adreno_wave_size;
|
||||
|
||||
cl_bool non_uniform_workgroups;
|
||||
size_t image_max_buffer_size;
|
||||
|
||||
cl_context context;
|
||||
cl_command_queue queue;
|
||||
|
|
@ -407,6 +408,10 @@ struct ggml_backend_opencl_context {
|
|||
ggml_cl_buffer prealloc_scales_trans;
|
||||
ggml_cl_buffer prealloc_act_trans;
|
||||
|
||||
// prealloc buffers for src0 and src1
|
||||
ggml_cl_buffer prealloc_src0;
|
||||
ggml_cl_buffer prealloc_src1;
|
||||
|
||||
cl_program program_add;
|
||||
cl_program program_add_id;
|
||||
cl_program program_clamp;
|
||||
|
|
@ -2658,6 +2663,9 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
|
|||
clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL);
|
||||
GGML_LOG_INFO("ggml_opencl: max mem alloc size: %zu MB\n", backend_ctx->max_alloc_size/1024/1024);
|
||||
|
||||
clGetDeviceInfo(device, CL_DEVICE_IMAGE_MAX_BUFFER_SIZE, sizeof(size_t), &backend_ctx->image_max_buffer_size, NULL);
|
||||
GGML_LOG_INFO("ggml_opencl: device max image buffer size (pixels): %lu\n", backend_ctx->image_max_buffer_size);
|
||||
|
||||
clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), &backend_ctx->max_workgroup_size, NULL);
|
||||
GGML_LOG_INFO("ggml_opencl: device max workgroup size: %lu\n", backend_ctx->max_workgroup_size);
|
||||
|
||||
|
|
@ -4711,6 +4719,81 @@ static bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct gg
|
|||
(ne0 >= 32 && ne1 >= 32 && ne10 >= 32);
|
||||
}
|
||||
|
||||
// Copy a noncontiguous tensor to contiguous tensor. ne[] remains the same but
|
||||
// nb[] is recalculated such that tensor is contiguous.
|
||||
static void ggml_cl_copy_to_contiguous(ggml_backend_t backend, const ggml_tensor * src, cl_mem dst,
|
||||
cl_ulong &nb0, cl_ulong &nb1, cl_ulong &nb2, cl_ulong &nb3) {
|
||||
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||
|
||||
const int tensor_type_size = ggml_type_size(src->type);
|
||||
|
||||
const int ne00 = src->ne[0];
|
||||
const int ne01 = src->ne[1];
|
||||
const int ne02 = src->ne[2];
|
||||
const int ne03 = src->ne[3];
|
||||
|
||||
const cl_ulong nb00 = src->nb[0];
|
||||
const cl_ulong nb01 = src->nb[1];
|
||||
const cl_ulong nb02 = src->nb[2];
|
||||
const cl_ulong nb03 = src->nb[3];
|
||||
|
||||
const int ne0 = src->ne[0];
|
||||
const int ne1 = src->ne[1];
|
||||
const int ne2 = src->ne[2];
|
||||
const int ne3 = src->ne[3];
|
||||
|
||||
nb0 = tensor_type_size;
|
||||
nb1 = tensor_type_size*ne00;
|
||||
nb2 = tensor_type_size*ne00*ne01;
|
||||
nb3 = tensor_type_size*ne00*ne01*ne02;
|
||||
|
||||
ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *)src->extra;
|
||||
|
||||
cl_ulong offset0 = extra->offset + src->view_offs;
|
||||
cl_ulong offsetd = 0;
|
||||
|
||||
cl_kernel kernel;
|
||||
|
||||
switch (src->type) {
|
||||
case GGML_TYPE_F32:
|
||||
kernel = backend_ctx->kernel_cpy_f32_f32;
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
kernel = backend_ctx->kernel_cpy_f16_f16;
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "not implemented");
|
||||
}
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &dst));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne3));
|
||||
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb3));
|
||||
|
||||
const int nth = MIN(64, ne00);
|
||||
|
||||
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
|
||||
size_t local_work_size[] = {(size_t)nth, 1, 1};
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, src);
|
||||
}
|
||||
|
||||
static void ggml_cl_nop(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
UNUSED(backend);
|
||||
UNUSED(src0);
|
||||
|
|
@ -7724,9 +7807,12 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
|||
cl_context context = backend_ctx->context;
|
||||
|
||||
if(src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32){
|
||||
if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0) {
|
||||
if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0 &&
|
||||
// dst is wrapped with image1d_buffer, the size limit applies, also src0
|
||||
(ne0 * ne1 * dst->ne[2] * dst->nb[0] / 4 <= backend_ctx->image_max_buffer_size)) {
|
||||
// For KQ
|
||||
if (ggml_is_permuted(src0) && ggml_is_permuted(src1) &&
|
||||
((nb01 * ne01 / 4)/4 <= backend_ctx->image_max_buffer_size) &&
|
||||
nb00 <= nb02 &&
|
||||
nb02 <= nb01 &&
|
||||
nb01 <= nb03 &&
|
||||
|
|
@ -7737,7 +7823,8 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
|||
return;
|
||||
}
|
||||
// For KQV
|
||||
if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
|
||||
if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1) &&
|
||||
((nb02 * ne02 / 4)/4 <= backend_ctx->image_max_buffer_size)) {
|
||||
ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst);
|
||||
return;
|
||||
}
|
||||
|
|
@ -8043,9 +8130,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
|||
|
||||
// GEMM using local memory
|
||||
// Current BK = 16, so ne00 % 16 == 0
|
||||
if (ggml_is_contiguous(src0) &&
|
||||
ggml_is_contiguous(src1) &&
|
||||
src1t == GGML_TYPE_F32 &&
|
||||
if (src1t == GGML_TYPE_F32 &&
|
||||
ne00 % 16 == 0 &&
|
||||
ne11 > 1) {
|
||||
switch(src0t) {
|
||||
|
|
@ -8057,10 +8142,42 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
|||
int batch_stride_b = ne10*ne11;
|
||||
int batch_stride_d = ne0*ne1;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
|
||||
cl_mem mem_src0 = extra0->data_device;
|
||||
cl_mem mem_src1 = extra1->data_device;
|
||||
|
||||
cl_ulong nb00_cont = nb00;
|
||||
cl_ulong nb01_cont = nb01;
|
||||
cl_ulong nb02_cont = nb02;
|
||||
cl_ulong nb03_cont = nb03;
|
||||
|
||||
cl_ulong nb10_cont = nb10;
|
||||
cl_ulong nb11_cont = nb11;
|
||||
cl_ulong nb12_cont = nb12;
|
||||
cl_ulong nb13_cont = nb13;
|
||||
|
||||
cl_ulong offset0_cont = offset0;
|
||||
cl_ulong offset1_cont = offset1;
|
||||
|
||||
if (!ggml_is_contiguous(src0)) {
|
||||
backend_ctx->prealloc_src0.allocate(backend_ctx->context, ggml_nbytes(src0));
|
||||
ggml_cl_copy_to_contiguous(backend, src0, backend_ctx->prealloc_src0.buffer,
|
||||
nb00_cont, nb01_cont, nb02_cont, nb03_cont);
|
||||
mem_src0 = backend_ctx->prealloc_src0.buffer;
|
||||
offset0_cont = 0;
|
||||
}
|
||||
|
||||
if (!ggml_is_contiguous(src1)) {
|
||||
backend_ctx->prealloc_src1.allocate(backend_ctx->context, ggml_nbytes(src1));
|
||||
ggml_cl_copy_to_contiguous(backend, src1, backend_ctx->prealloc_src1.buffer,
|
||||
nb10_cont, nb11_cont, nb12_cont, nb13_cont);
|
||||
mem_src1 = backend_ctx->prealloc_src1.buffer;
|
||||
offset1_cont = 0;
|
||||
}
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &mem_src0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_cont));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &mem_src1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1_cont));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
|
||||
|
|
@ -8092,10 +8209,42 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
|||
int batch_stride_b = ne10*ne11;
|
||||
int batch_stride_d = ne0*ne1;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
|
||||
cl_mem mem_src0 = extra0->data_device;
|
||||
cl_mem mem_src1 = extra1->data_device;
|
||||
|
||||
cl_ulong nb00_cont = nb00;
|
||||
cl_ulong nb01_cont = nb01;
|
||||
cl_ulong nb02_cont = nb02;
|
||||
cl_ulong nb03_cont = nb03;
|
||||
|
||||
cl_ulong nb10_cont = nb10;
|
||||
cl_ulong nb11_cont = nb11;
|
||||
cl_ulong nb12_cont = nb12;
|
||||
cl_ulong nb13_cont = nb13;
|
||||
|
||||
cl_ulong offset0_cont = offset0;
|
||||
cl_ulong offset1_cont = offset1;
|
||||
|
||||
if (!ggml_is_contiguous(src0)) {
|
||||
backend_ctx->prealloc_src0.allocate(backend_ctx->context, ggml_nbytes(src0));
|
||||
ggml_cl_copy_to_contiguous(backend, src0, backend_ctx->prealloc_src0.buffer,
|
||||
nb00_cont, nb01_cont, nb02_cont, nb03_cont);
|
||||
mem_src0 = backend_ctx->prealloc_src0.buffer;
|
||||
offset0_cont = 0;
|
||||
}
|
||||
|
||||
if (!ggml_is_contiguous(src1)) {
|
||||
backend_ctx->prealloc_src1.allocate(backend_ctx->context, ggml_nbytes(src1));
|
||||
ggml_cl_copy_to_contiguous(backend, src1, backend_ctx->prealloc_src1.buffer,
|
||||
nb10_cont, nb11_cont, nb12_cont, nb13_cont);
|
||||
mem_src1 = backend_ctx->prealloc_src1.buffer;
|
||||
offset1_cont = 0;
|
||||
}
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &mem_src0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_cont));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &mem_src1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1_cont));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
|
||||
|
|
@ -8123,6 +8272,10 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
|||
if (ne11 < 32) {
|
||||
break;
|
||||
}
|
||||
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
|
||||
break;
|
||||
}
|
||||
|
||||
kernel = backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm;
|
||||
nth0 = 128; // calculated as (BM*BN)/(TM*TN)
|
||||
|
||||
|
|
|
|||
|
|
@ -1157,13 +1157,28 @@ static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_
|
|||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
inline void * aligned_malloc_host(size_t alignment, size_t size) {
|
||||
#ifdef _WIN32
|
||||
return _aligned_malloc(size, alignment);
|
||||
#else
|
||||
return aligned_alloc(alignment, size);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline void free_aligned_mem_host(void * memblock) {
|
||||
#ifdef _WIN32
|
||||
_aligned_free(memblock);
|
||||
#else
|
||||
free(memblock);
|
||||
#endif
|
||||
}
|
||||
|
||||
static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||
ggml_sycl_host_free(buffer->context);
|
||||
free_aligned_mem_host((void *)buffer->context);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
||||
void * ptr = ggml_sycl_host_malloc(size);
|
||||
|
||||
void * ptr = aligned_malloc_host(TENSOR_ALIGNMENT, size);
|
||||
if (ptr == nullptr) {
|
||||
// fallback to cpu buffer
|
||||
return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ LLAMA_BENCH_DB_FIELDS = [
|
|||
"cpu_mask", "cpu_strict", "poll", "type_k", "type_v", "n_gpu_layers",
|
||||
"split_mode", "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "tensor_buft_overrides",
|
||||
"use_mmap", "embeddings", "no_op_offload", "n_prompt", "n_gen", "n_depth",
|
||||
"test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts",
|
||||
"test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts", "n_cpu_moe"
|
||||
]
|
||||
|
||||
LLAMA_BENCH_DB_TYPES = [
|
||||
|
|
@ -38,7 +38,7 @@ LLAMA_BENCH_DB_TYPES = [
|
|||
"TEXT", "INTEGER", "INTEGER", "TEXT", "TEXT", "INTEGER",
|
||||
"TEXT", "INTEGER", "INTEGER", "INTEGER", "TEXT", "TEXT",
|
||||
"INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER",
|
||||
"TEXT", "INTEGER", "INTEGER", "REAL", "REAL",
|
||||
"TEXT", "INTEGER", "INTEGER", "REAL", "REAL", "INTEGER",
|
||||
]
|
||||
|
||||
# All test-backend-ops SQL fields
|
||||
|
|
@ -59,7 +59,7 @@ assert len(TEST_BACKEND_OPS_DB_FIELDS) == len(TEST_BACKEND_OPS_DB_TYPES)
|
|||
|
||||
# Properties by which to differentiate results per commit for llama-bench:
|
||||
LLAMA_BENCH_KEY_PROPERTIES = [
|
||||
"cpu_info", "gpu_info", "backends", "n_gpu_layers", "tensor_buft_overrides", "model_filename", "model_type",
|
||||
"cpu_info", "gpu_info", "backends", "n_gpu_layers", "n_cpu_moe", "tensor_buft_overrides", "model_filename", "model_type",
|
||||
"n_batch", "n_ubatch", "embeddings", "cpu_mask", "cpu_strict", "poll", "n_threads", "type_k", "type_v",
|
||||
"use_mmap", "no_kv_offload", "split_mode", "main_gpu", "tensor_split", "flash_attn", "n_prompt", "n_gen", "n_depth"
|
||||
]
|
||||
|
|
|
|||
|
|
@ -2903,7 +2903,7 @@ void llama_context::opt_epoch_iter(
|
|||
};
|
||||
ctx_compute_opt = ggml_init(params);
|
||||
}
|
||||
ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
|
||||
ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_inp_tokens(), res->get_logits());
|
||||
ggml_opt_alloc(opt_ctx, train);
|
||||
|
||||
res->set_inputs(&ubatch);
|
||||
|
|
|
|||
|
|
@ -23,7 +23,8 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
|
|||
}
|
||||
|
||||
if (ubatch->embd) {
|
||||
const int64_t n_embd = embd->ne[0];
|
||||
GGML_ASSERT(n_embd == embd->ne[0]);
|
||||
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
|
||||
ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
|
||||
|
|
@ -33,8 +34,8 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
|
|||
bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
|
||||
bool res = true;
|
||||
|
||||
res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
|
||||
res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens);
|
||||
res &= (!params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
|
||||
res &= (!params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
|
@ -634,7 +635,8 @@ int64_t llm_graph_result::get_max_nodes() const {
|
|||
}
|
||||
|
||||
void llm_graph_result::reset() {
|
||||
t_tokens = nullptr;
|
||||
t_inp_tokens = nullptr;
|
||||
t_inp_embd = nullptr;
|
||||
t_logits = nullptr;
|
||||
t_embd = nullptr;
|
||||
t_embd_pooled = nullptr;
|
||||
|
|
@ -1338,17 +1340,29 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|||
|
||||
// input embeddings with optional lora
|
||||
ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
|
||||
const int64_t n_embd = hparams.n_embd_inp();
|
||||
const int64_t n_embd_inp = hparams.n_embd_inp();
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_embd>();
|
||||
assert(n_embd_inp >= n_embd);
|
||||
|
||||
ggml_tensor * cur = nullptr;
|
||||
auto inp = std::make_unique<llm_graph_input_embd>(n_embd_inp);
|
||||
|
||||
if (ubatch.token) {
|
||||
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
|
||||
//cb(inp->tokens, "inp_tokens", -1);
|
||||
ggml_set_input(inp->tokens);
|
||||
res->t_tokens = inp->tokens;
|
||||
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
|
||||
cb(inp->tokens, "inp_tokens", -1);
|
||||
ggml_set_input(inp->tokens);
|
||||
res->t_inp_tokens = inp->tokens;
|
||||
|
||||
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_inp, ubatch.n_tokens);
|
||||
cb(inp->embd, "inp_embd", -1);
|
||||
ggml_set_input(inp->embd);
|
||||
|
||||
// select one of the 2 inputs, based on the batch contents
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/18550
|
||||
std::array<ggml_tensor *, 2> inps;
|
||||
|
||||
// token embeddings path (ubatch.token != nullptr)
|
||||
{
|
||||
auto & cur = inps[0];
|
||||
|
||||
cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
|
||||
|
||||
|
|
@ -1369,19 +1383,36 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
|
|||
|
||||
cur = ggml_add(ctx0, cur, inpL_delta);
|
||||
}
|
||||
} else {
|
||||
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
|
||||
ggml_set_input(inp->embd);
|
||||
|
||||
if (n_embd_inp != n_embd) {
|
||||
cur = ggml_pad(ctx0, cur, hparams.n_embd_inp() - n_embd, 0, 0, 0);
|
||||
}
|
||||
}
|
||||
|
||||
// vector embeddings path (ubatch.embd != nullptr)
|
||||
{
|
||||
auto & cur = inps[1];
|
||||
|
||||
cur = inp->embd;
|
||||
}
|
||||
|
||||
assert(ggml_are_same_shape (inps[0], inps[1]));
|
||||
assert(ggml_are_same_stride(inps[0], inps[1]));
|
||||
|
||||
ggml_tensor * cur = ggml_build_forward_select(gf, inps.data(), inps.size(), ubatch.token ? 0 : 1);
|
||||
|
||||
if (n_embd_inp != n_embd) {
|
||||
cur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0);
|
||||
}
|
||||
|
||||
res->t_inp_embd = cur;
|
||||
|
||||
// For Granite architecture
|
||||
if (hparams.f_embedding_scale != 0.0f) {
|
||||
cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale);
|
||||
}
|
||||
|
||||
cb(cur, "inp_embd", -1);
|
||||
cb(cur, "embd", -1);
|
||||
|
||||
res->add_input(std::move(inp));
|
||||
|
||||
|
|
@ -1480,7 +1511,7 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
|
|||
//}
|
||||
|
||||
const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd_inp();
|
||||
const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
|
||||
const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
|
||||
|
||||
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
|
||||
ggml_set_input(cur);
|
||||
|
|
@ -1565,6 +1596,11 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
|||
v = ggml_transpose(ctx0, v);
|
||||
}
|
||||
|
||||
// TODO: update llama_kv_cache to not store V cache in the MLA case and automatically return a view of K
|
||||
if (v_mla) {
|
||||
v = ggml_view_4d(ctx0, k, v->ne[0], v->ne[1], v->ne[2], v->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
|
||||
}
|
||||
|
||||
// this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
|
||||
if (k->type == GGML_TYPE_F32) {
|
||||
k = ggml_cast(ctx0, k, GGML_TYPE_F16);
|
||||
|
|
|
|||
|
|
@ -106,7 +106,7 @@ using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
|
|||
|
||||
class llm_graph_input_embd : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_embd() = default;
|
||||
llm_graph_input_embd(int64_t n_embd) : n_embd(n_embd) {}
|
||||
virtual ~llm_graph_input_embd() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
|
@ -115,6 +115,8 @@ public:
|
|||
|
||||
ggml_tensor * tokens = nullptr; // I32 [n_batch]
|
||||
ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
|
||||
|
||||
const int64_t n_embd = 0;
|
||||
};
|
||||
|
||||
class llm_graph_input_pos : public llm_graph_input_i {
|
||||
|
|
@ -566,7 +568,7 @@ public:
|
|||
|
||||
virtual ~llm_graph_result() = default;
|
||||
|
||||
ggml_tensor * get_tokens() const { return t_tokens; }
|
||||
ggml_tensor * get_inp_tokens() const { return t_inp_tokens; }
|
||||
ggml_tensor * get_logits() const { return t_logits; }
|
||||
ggml_tensor * get_embd() const { return t_embd; }
|
||||
ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
|
||||
|
|
@ -593,7 +595,8 @@ public:
|
|||
void set_params(const llm_graph_params & params);
|
||||
|
||||
// important graph nodes
|
||||
ggml_tensor * t_tokens = nullptr;
|
||||
ggml_tensor * t_inp_tokens = nullptr;
|
||||
ggml_tensor * t_inp_embd = nullptr; // [n_embd_inp, n_tokens]
|
||||
ggml_tensor * t_logits = nullptr;
|
||||
ggml_tensor * t_embd = nullptr;
|
||||
ggml_tensor * t_embd_pooled = nullptr;
|
||||
|
|
|
|||
|
|
@ -1594,6 +1594,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
|
|||
const auto & n_embd_head_k = hparams.n_embd_head_k;
|
||||
//const auto & n_embd_head_v = hparams.n_embd_head_v;
|
||||
|
||||
const auto & n_rot = hparams.n_rot;
|
||||
|
||||
const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - n_rot : 0;
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_k_shift>(this);
|
||||
|
||||
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
|
||||
|
|
@ -1614,10 +1618,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
|
|||
|
||||
ggml_tensor * k =
|
||||
ggml_view_3d(ctx, layer.k,
|
||||
n_embd_head_k, n_head_kv, get_size()*n_stream,
|
||||
n_rot, n_head_kv, get_size()*n_stream,
|
||||
ggml_row_size(layer.k->type, n_embd_head_k),
|
||||
ggml_row_size(layer.k->type, n_embd_k_gqa),
|
||||
0);
|
||||
ggml_row_size(layer.k->type, n_embd_nope));
|
||||
|
||||
ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
|
||||
|
||||
|
|
|
|||
|
|
@ -124,14 +124,14 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
|
|||
|
||||
// {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
|
||||
// note: rope must go first for in-place context shifting in build_rope_shift()
|
||||
ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0);
|
||||
ggml_tensor * Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
|
||||
cb(kv_cmpr, "kv_cmpr_reshape", il);
|
||||
|
||||
// {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens}
|
||||
ggml_tensor * Kcur = ggml_concat(ctx0, k_pe, kv_cmpr, 0);
|
||||
ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
// {kv_lora_rank, 1, n_tokens}
|
||||
|
|
@ -169,11 +169,10 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
|
|||
Vcur = ggml_cont(ctx0, Vcur);
|
||||
cb(Vcur, "Vcur_cont", il);
|
||||
|
||||
// note: rope must go first for in-place context shifting in build_rope_shift()
|
||||
ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope, 0);
|
||||
ggml_tensor * Qcur = ggml_concat(ctx0, q_nope, q_pe, 0);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
ggml_tensor * Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe), k_nope, 0);
|
||||
ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
if (inp_attn_scale) {
|
||||
|
|
|
|||
|
|
@ -245,12 +245,12 @@ ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice(ggml_tensor * x, int idx) {
|
|||
// equivalent to get_per_layer_inputs() in python code
|
||||
// output shape: [n_embd_altup, n_layer, n_tokens]
|
||||
ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
|
||||
auto inp = std::make_unique<llm_graph_input_embd>();
|
||||
auto inp = std::make_unique<llm_graph_input_embd>(n_embd);
|
||||
ggml_tensor * inp_per_layer;
|
||||
if (ubatch.token) {
|
||||
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
|
||||
ggml_set_input(inp->tokens);
|
||||
res->t_tokens = inp->tokens;
|
||||
res->t_inp_tokens = inp->tokens;
|
||||
inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
|
||||
inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
|
||||
inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup));
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_grap
|
|||
|
||||
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
|
||||
const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
|
||||
|
||||
const uint32_t kv_lora_rank = hparams.n_lora_kv;
|
||||
|
||||
ggml_tensor * cur;
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ llm_build_plm::llm_build_plm(const llama_model & model, const llm_graph_params &
|
|||
|
||||
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
|
||||
const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
|
||||
|
||||
const uint32_t kv_lora_rank = hparams.n_lora_kv;
|
||||
|
||||
ggml_tensor * cur;
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@
|
|||
|
||||
llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const size_t n_deepstack_layers = hparams.n_deepstack_layers;
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
|
|
@ -16,17 +17,6 @@ llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_
|
|||
int sections[4];
|
||||
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
|
||||
|
||||
std::vector<ggml_tensor *> deepstack_features(n_deepstack_layers, nullptr);
|
||||
|
||||
if (ubatch.embd) {
|
||||
// Image input: split main embd and deepstack embds
|
||||
ggml_tensor * inpL_main = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], 0);
|
||||
for (size_t i = 0; i < n_deepstack_layers; i++) {
|
||||
deepstack_features[i] = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], (i + 1) * n_embd * sizeof(float));
|
||||
}
|
||||
inpL = inpL_main;
|
||||
}
|
||||
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
|
|
@ -120,8 +110,9 @@ llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_
|
|||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
if (ubatch.embd && (size_t)il < n_deepstack_layers) {
|
||||
cur = ggml_add(ctx0, cur, deepstack_features[il]);
|
||||
if (il < (int) n_deepstack_layers) {
|
||||
ggml_tensor * ds = ggml_view_2d(ctx0, res->t_inp_embd, n_embd, n_tokens, res->t_inp_embd->nb[1], (il + 1) * n_embd * sizeof(float));
|
||||
cur = ggml_add(ctx0, cur, ds);
|
||||
cb(cur, "deepstack_out", il);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@
|
|||
|
||||
llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const size_t n_deepstack_layers = hparams.n_deepstack_layers;
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
|
|
@ -16,17 +17,6 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_
|
|||
int sections[4];
|
||||
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
|
||||
|
||||
std::vector<ggml_tensor *> deepstack_features(n_deepstack_layers, nullptr);
|
||||
|
||||
if (ubatch.embd) {
|
||||
// Image input: split main embd and deepstack embds
|
||||
ggml_tensor * inpL_main = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], 0);
|
||||
for (size_t i = 0; i < n_deepstack_layers; i++) {
|
||||
deepstack_features[i] = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], (i + 1) * n_embd * sizeof(float));
|
||||
}
|
||||
inpL = inpL_main;
|
||||
}
|
||||
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
|
|
@ -113,8 +103,9 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_
|
|||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
if (ubatch.embd && (size_t)il < n_deepstack_layers) {
|
||||
cur = ggml_add(ctx0, cur, deepstack_features[il]);
|
||||
if (il < (int) n_deepstack_layers) {
|
||||
ggml_tensor * ds = ggml_view_2d(ctx0, res->t_inp_embd, n_embd, n_tokens, res->t_inp_embd->nb[1], (il + 1) * n_embd * sizeof(float));
|
||||
cur = ggml_add(ctx0, cur, ds);
|
||||
cb(cur, "deepstack_out", il);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6122,7 +6122,19 @@ struct test_flash_attn_ext : public test_case {
|
|||
ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, nr23[1], true); // the K tensor is usually a view of the K cache
|
||||
ggml_set_name(k, "k");
|
||||
|
||||
ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache
|
||||
ggml_tensor * v = nullptr;
|
||||
if (hsk_padded == 576 && hsv_padded == 512) {
|
||||
// TODO: this branch should become a separate test case parameter instead of hardcoding this for these head shapes
|
||||
|
||||
// in this branch, the V cache is sub-view of the K cache. this is used by some MLA-based models
|
||||
// for more info:
|
||||
// - https://github.com/ggml-org/llama.cpp/pull/13435
|
||||
// - https://github.com/ggml-org/llama.cpp/pull/18953#issuecomment-3774948392
|
||||
// - https://github.com/ggml-org/llama.cpp/pull/18986
|
||||
v = ggml_view_4d(ctx, k, hsv_padded, kv, nh, nr23[1], k->nb[1], k->nb[2], k->nb[3], 0);
|
||||
} else {
|
||||
v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache
|
||||
}
|
||||
ggml_set_name(v, "v");
|
||||
|
||||
ggml_tensor * m = nullptr;
|
||||
|
|
|
|||
|
|
@ -462,9 +462,9 @@ static void test_parser_with_streaming(const common_chat_msg & expected, const s
|
|||
for (size_t i = 1; i <= raw_message.size(); ++i) {
|
||||
auto curr_msg = parse_msg(std::string(utf8_truncate_safe_view(std::string_view(raw_message).substr(0, i))));
|
||||
if (curr_msg == simple_assist_msg("")) continue;
|
||||
LOG_INF("Streaming msg: %s\n", common_chat_msgs_to_json_oaicompat<json>({curr_msg}).dump().c_str());
|
||||
LOG_INF("Streaming msg: %s\n", common_chat_msgs_to_json_oaicompat({curr_msg}).dump().c_str());
|
||||
for (auto diff: common_chat_msg_diff::compute_diffs(last_msg, curr_msg)) {
|
||||
LOG_INF("Streaming diff: %s\n", common_chat_msg_diff_to_json_oaicompat<json>(diff).dump().c_str());
|
||||
LOG_INF("Streaming diff: %s\n", common_chat_msg_diff_to_json_oaicompat(diff).dump().c_str());
|
||||
if (!diff.reasoning_content_delta.empty()) {
|
||||
merged.reasoning_content += diff.reasoning_content_delta;
|
||||
}
|
||||
|
|
@ -480,7 +480,7 @@ static void test_parser_with_streaming(const common_chat_msg & expected, const s
|
|||
merged.tool_calls.back().arguments += diff.tool_call_delta.arguments;
|
||||
}
|
||||
}
|
||||
LOG_INF("Streaming merged: %s\n", common_chat_msgs_to_json_oaicompat<json>({merged}).dump().c_str());
|
||||
LOG_INF("Streaming merged: %s\n", common_chat_msgs_to_json_oaicompat({merged}).dump().c_str());
|
||||
}
|
||||
assert_msg_equals(curr_msg, merged, true);
|
||||
last_msg = curr_msg;
|
||||
|
|
@ -622,7 +622,7 @@ static void test_msgs_oaicompat_json_conversion() {
|
|||
message_assist_call_code_interpreter,
|
||||
};
|
||||
for (const auto & msg : msgs) {
|
||||
auto oai_json = common_chat_msgs_to_json_oaicompat<json>({msg});
|
||||
auto oai_json = common_chat_msgs_to_json_oaicompat({msg});
|
||||
auto msgs2 = common_chat_msgs_parse_oaicompat(oai_json);
|
||||
assert_equals((size_t) 1, msgs2.size());
|
||||
auto msg2 = msgs2[0];
|
||||
|
|
@ -646,7 +646,7 @@ static void test_msgs_oaicompat_json_conversion() {
|
|||
" }\n"
|
||||
"]"
|
||||
),
|
||||
common_chat_msgs_to_json_oaicompat<json>({message_user_parts}).dump(2));
|
||||
common_chat_msgs_to_json_oaicompat({message_user_parts}).dump(2));
|
||||
|
||||
assert_equals(
|
||||
std::string(
|
||||
|
|
@ -666,7 +666,7 @@ static void test_msgs_oaicompat_json_conversion() {
|
|||
" }\n"
|
||||
"]"
|
||||
),
|
||||
common_chat_msgs_to_json_oaicompat<json>({message_assist_call_python}).dump(2));
|
||||
common_chat_msgs_to_json_oaicompat({message_assist_call_python}).dump(2));
|
||||
|
||||
auto res = common_chat_msgs_parse_oaicompat(json::parse("[{\"role\": \"assistant\", \"tool_calls\": []}]"));
|
||||
assert_equals<size_t>(1, res.size());
|
||||
|
|
@ -693,7 +693,7 @@ static void test_tools_oaicompat_json_conversion() {
|
|||
};
|
||||
|
||||
for (const auto & tool : tools) {
|
||||
auto oai_json = common_chat_tools_to_json_oaicompat<json>({tool});
|
||||
auto oai_json = common_chat_tools_to_json_oaicompat({tool});
|
||||
auto tools2 = common_chat_tools_parse_oaicompat(oai_json);
|
||||
assert_equals((size_t) 1, tools2.size());
|
||||
auto tool2 = tools2[0];
|
||||
|
|
@ -726,7 +726,7 @@ static void test_tools_oaicompat_json_conversion() {
|
|||
" }\n"
|
||||
"]"
|
||||
),
|
||||
common_chat_tools_to_json_oaicompat<json>({special_function_tool}).dump(2));
|
||||
common_chat_tools_to_json_oaicompat({special_function_tool}).dump(2));
|
||||
|
||||
{
|
||||
auto tools_no_params = common_chat_tools_parse_oaicompat(json::parse(
|
||||
|
|
|
|||
|
|
@ -84,6 +84,9 @@ struct cli_context {
|
|||
// chat template settings
|
||||
task.params.chat_parser_params = common_chat_parser_params(chat_params);
|
||||
task.params.chat_parser_params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
if (!chat_params.parser.empty()) {
|
||||
task.params.chat_parser_params.parser.load(chat_params.parser);
|
||||
}
|
||||
|
||||
rd.post_task({std::move(task)});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -781,6 +781,7 @@ By default, it is read-only. To make POST request to change global properties, y
|
|||
"total_slots": 1,
|
||||
"model_path": "../models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",
|
||||
"chat_template": "...",
|
||||
"chat_template_caps": {},
|
||||
"modalities": {
|
||||
"vision": false
|
||||
},
|
||||
|
|
@ -793,6 +794,7 @@ By default, it is read-only. To make POST request to change global properties, y
|
|||
- `total_slots` - the total number of slots for process requests (defined by `--parallel` option)
|
||||
- `model_path` - the path to model file (same with `-m` argument)
|
||||
- `chat_template` - the model's original Jinja2 prompt template
|
||||
- `chat_template_caps` - capabilities of the chat template (see `common/jinja/caps.h` for more info)
|
||||
- `modalities` - the list of supported modalities
|
||||
- `is_sleeping` - sleeping status, see [Sleeping on idle](#sleeping-on-idle)
|
||||
|
||||
|
|
@ -1267,6 +1269,12 @@ This provides information on the performance of the server. It also allows calcu
|
|||
|
||||
The total number of tokens in context is equal to `prompt_n + cache_n + predicted_n`
|
||||
|
||||
*Reasoning support*
|
||||
|
||||
The server supports parsing and returning reasoning via the `reasoning_content` field, similar to Deepseek API.
|
||||
|
||||
Reasoning input (preserve reasoning in history) is also supported by some specific templates. For more details, please refer to [PR#18994](https://github.com/ggml-org/llama.cpp/pull/18994).
|
||||
|
||||
### POST `/v1/responses`: OpenAI-compatible Responses API
|
||||
|
||||
*Options:*
|
||||
|
|
|
|||
|
|
@ -2903,6 +2903,7 @@ server_context_meta server_context::get_meta() const {
|
|||
/* pooling_type */ llama_pooling_type(impl->ctx),
|
||||
|
||||
/* chat_params */ impl->chat_params,
|
||||
/* chat_template_caps */ common_chat_templates_get_caps(impl->chat_params.tmpls.get()),
|
||||
|
||||
/* bos_token_str */ bos_token_str,
|
||||
/* eos_token_str */ eos_token_str,
|
||||
|
|
@ -3410,6 +3411,7 @@ void server_routes::init_routes() {
|
|||
{ "webui", params.webui },
|
||||
{ "webui_settings", meta->json_webui_settings },
|
||||
{ "chat_template", tmpl_default },
|
||||
{ "chat_template_caps", meta->chat_template_caps },
|
||||
{ "bos_token", meta->bos_token_str },
|
||||
{ "eos_token", meta->eos_token_str },
|
||||
{ "build_info", meta->build_info },
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ struct server_context_meta {
|
|||
|
||||
// chat params
|
||||
server_chat_params & chat_params;
|
||||
std::map<std::string, bool> chat_template_caps;
|
||||
|
||||
// tokens
|
||||
std::string bos_token_str;
|
||||
|
|
|
|||
|
|
@ -28,14 +28,20 @@ server_http_context::server_http_context()
|
|||
server_http_context::~server_http_context() = default;
|
||||
|
||||
static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
|
||||
// skip GH copilot requests when using default port
|
||||
if (req.path == "/v1/health") {
|
||||
// skip logging requests that are regularly sent, to avoid log spam
|
||||
if (req.path == "/health"
|
||||
|| req.path == "/v1/health"
|
||||
|| req.path == "/models"
|
||||
|| req.path == "/v1/models"
|
||||
|| req.path == "/props"
|
||||
|| req.path == "/metrics"
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
// reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch
|
||||
|
||||
SRV_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);
|
||||
SRV_INF("done request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);
|
||||
|
||||
SRV_DBG("request: %s\n", req.body.c_str());
|
||||
SRV_DBG("response: %s\n", res.body.c_str());
|
||||
|
|
|
|||
|
|
@ -700,7 +700,7 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat() {
|
|||
json choice {
|
||||
{"finish_reason", finish_reason},
|
||||
{"index", index},
|
||||
{"message", msg.to_json_oaicompat<json>()},
|
||||
{"message", msg.to_json_oaicompat()},
|
||||
};
|
||||
|
||||
if (!stream && probs_output.size() > 0) {
|
||||
|
|
@ -750,7 +750,7 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() {
|
|||
json {
|
||||
{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"delta", common_chat_msg_diff_to_json_oaicompat<json>(diff)},
|
||||
{"delta", common_chat_msg_diff_to_json_oaicompat(diff)},
|
||||
},
|
||||
})},
|
||||
{"created", t},
|
||||
|
|
@ -1383,7 +1383,7 @@ json server_task_result_cmpl_partial::to_json_oaicompat_chat() {
|
|||
}
|
||||
|
||||
for (const auto & diff : oaicompat_msg_diffs) {
|
||||
add_delta(common_chat_msg_diff_to_json_oaicompat<json>(diff));
|
||||
add_delta(common_chat_msg_diff_to_json_oaicompat(diff));
|
||||
}
|
||||
|
||||
if (!deltas.empty()) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue