diff --git a/.devops/vulkan.Dockerfile b/.devops/vulkan.Dockerfile
index fcd81ffa1e..6cf87c67e8 100644
--- a/.devops/vulkan.Dockerfile
+++ b/.devops/vulkan.Dockerfile
@@ -2,14 +2,30 @@ ARG UBUNTU_VERSION=24.04
FROM ubuntu:$UBUNTU_VERSION AS build
-# Install build tools
-RUN apt update && apt install -y git build-essential cmake wget
+# Ref: https://vulkan.lunarg.com/doc/sdk/latest/linux/getting_started.html
-# Install Vulkan SDK and cURL
-RUN wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | apt-key add - && \
- wget -qO /etc/apt/sources.list.d/lunarg-vulkan-noble.list https://packages.lunarg.com/vulkan/lunarg-vulkan-noble.list && \
- apt update -y && \
- apt-get install -y vulkan-sdk libcurl4-openssl-dev curl
+# Install build tools
+RUN apt update && apt install -y git build-essential cmake wget xz-utils
+
+# Install Vulkan SDK
+ARG VULKAN_VERSION=1.4.321.1
+RUN ARCH=$(uname -m) && \
+ wget -qO /tmp/vulkan-sdk.tar.xz https://sdk.lunarg.com/sdk/download/${VULKAN_VERSION}/linux/vulkan-sdk-linux-${ARCH}-${VULKAN_VERSION}.tar.xz && \
+ mkdir -p /opt/vulkan && \
+ tar -xf /tmp/vulkan-sdk.tar.xz -C /tmp --strip-components=1 && \
+ mv /tmp/${ARCH}/* /opt/vulkan/ && \
+ rm -rf /tmp/*
+
+# Install cURL and Vulkan SDK dependencies.
+# The package index is refreshed in the same layer as the install so that a
+# cached "apt update" layer can never leave this step on stale package lists.
+RUN apt update && apt install -y libcurl4-openssl-dev curl \
+ libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev
+
+# Set environment variables
+ENV VULKAN_SDK=/opt/vulkan
+ENV PATH=$VULKAN_SDK/bin:$PATH
+ENV LD_LIBRARY_PATH=$VULKAN_SDK/lib:$LD_LIBRARY_PATH
+ENV CMAKE_PREFIX_PATH=$VULKAN_SDK:$CMAKE_PREFIX_PATH
+ENV PKG_CONFIG_PATH=$VULKAN_SDK/lib/pkgconfig:$PKG_CONFIG_PATH
# Build it
WORKDIR /app
diff --git a/README.md b/README.md
index 8446756384..a01ef6d503 100644
--- a/README.md
+++ b/README.md
@@ -151,6 +151,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- [x] [Bunny](https://github.com/BAAI-DCAI/Bunny)
- [x] [GLM-EDGE](https://huggingface.co/models?search=glm-edge)
- [x] [Qwen2-VL](https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d)
+- [x] [LFM2-VL](https://huggingface.co/collections/LiquidAI/lfm2-vl-68963bbc84a610f7638d5ffa)
diff --git a/common/arg.cpp b/common/arg.cpp
index 91c2229ebd..469650491b 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1106,7 +1106,7 @@ static void common_params_print_completion(common_params_context & ctx_arg) {
printf("\"\n\n");
printf(" case \"$prev\" in\n");
- printf(" --model)\n");
+ printf(" --model|-m)\n");
printf(" COMPREPLY=( $(compgen -f -X '!*.gguf' -- \"$cur\") $(compgen -d -- \"$cur\") )\n");
printf(" return 0\n");
printf(" ;;\n");
@@ -1755,7 +1755,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params) {
params.warmup = false;
}
- ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL}));
+ ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY}));
add_opt(common_arg(
{"--spm-infill"},
string_format(
@@ -2254,9 +2254,11 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
add_opt(common_arg(
{"-dt", "--defrag-thold"}, "N",
- string_format("KV cache defragmentation threshold (default: %.1f, < 0 - disabled)", (double)params.defrag_thold),
+ string_format("KV cache defragmentation threshold (DEPRECATED)"),
[](common_params & params, const std::string & value) {
- params.defrag_thold = std::stof(value);
+ GGML_UNUSED(params);
+ GGML_UNUSED(value);
+ LOG_WRN("DEPRECATED: --defrag-thold is deprecated and no longer necessary to specify\n");
}
).set_env("LLAMA_ARG_DEFRAG_THOLD"));
add_opt(common_arg(
@@ -2553,7 +2555,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--lora"}, "FNAME",
"path to LoRA adapter (can be repeated to use multiple adapters)",
[](common_params & params, const std::string & value) {
- params.lora_adapters.push_back({ std::string(value), 1.0, nullptr });
+ params.lora_adapters.push_back({ std::string(value), 1.0, "", "", nullptr });
}
// we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
@@ -2561,7 +2563,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--lora-scaled"}, "FNAME", "SCALE",
"path to LoRA adapter with user defined scaling (can be repeated to use multiple adapters)",
[](common_params & params, const std::string & fname, const std::string & scale) {
- params.lora_adapters.push_back({ fname, std::stof(scale), nullptr });
+ params.lora_adapters.push_back({ fname, std::stof(scale), "", "", nullptr });
}
// we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
@@ -3543,6 +3545,22 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
+ add_opt(common_arg(
+ {"--fim-qwen-30b-default"},
+ string_format("use default Qwen 3 Coder 30B A3B Instruct (note: can download weights from the internet)"),
+ [](common_params & params) {
+ params.model.hf_repo = "ggml-org/Qwen3-Coder-30B-A3B-Instruct-Q8_0-GGUF";
+ params.model.hf_file = "qwen3-coder-30b-a3b-instruct-q8_0.gguf";
+ params.port = 8012;
+ params.n_gpu_layers = 99;
+ params.flash_attn = true;
+ params.n_ubatch = 1024;
+ params.n_batch = 1024;
+ params.n_ctx = 0;
+ params.n_cache_reuse = 256;
+ }
+ ).set_examples({LLAMA_EXAMPLE_SERVER}));
+
add_opt(common_arg(
{ "--diffusion-steps" }, "N",
string_format("number of diffusion steps (default: %d)", params.diffusion.steps),
diff --git a/common/chat.cpp b/common/chat.cpp
index 7f6809a4ed..955c42852a 100644
--- a/common/chat.cpp
+++ b/common/chat.cpp
@@ -622,6 +622,7 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
case COMMON_CHAT_FORMAT_GRANITE: return "Granite";
case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
+ case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS";
default:
throw std::runtime_error("Unknown chat format");
}
@@ -1361,6 +1362,26 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
"<|end|>",
};
+ if (!inputs.json_schema.is_null()) {
+ data.grammar_lazy = false;
+ data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+ auto schema = inputs.json_schema;
+ builder.resolve_refs(schema);
+
+ auto not_end = builder.add_rule("not-end",
+ "[^<] | \"<\" [^|] | \"<|\" [^e] | \"<|e\" [^n] | \"<|en\" [^d] | \"<|end\" [^|] | \"<|end|\" [^>]");
+ auto analysis = builder.add_rule("analysis",
+ "\"<|channel|>analysis<|message|>\" ( " + not_end + " )* \"<|end|>\"");
+ auto constraint = builder.add_rule("constraint", "\"<|constrain|>\"? [a-zA-Z0-9_-]+");
+ auto final = builder.add_rule("final",
+ "\"<|channel|>final\" ( \" \" " + constraint + " )? \"<|message|>\" " +
+ builder.add_schema("response", schema)
+ );
+
+ builder.add_rule("root", "( " + analysis + " \"<|start|>assistant\" )? " + final);
+ });
+ }
+
if (inputs.tools.is_array() && !inputs.tools.empty()) {
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
@@ -2039,6 +2060,94 @@ static void common_chat_parse_granite(common_chat_msg_parser & builder) {
}
}
+// Parses a Seed-OSS assistant message held by `builder`: reasoning first, then zero or
+// more tool calls (function name plus parameters gathered into a JSON object), then any
+// remaining text as content. Throws common_chat_msg_partial_exception when a tool call
+// is cut off before its closing literals are seen.
+// NOTE(review): the tag literals and regexes in this listing look truncated (empty
+// strings, unbalanced ')' in the regexes) - confirm them against the real file.
+static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
+ // Parse thinking tags first - this handles the main reasoning content
+ builder.try_parse_reasoning("", "");
+
+ // Without tool-call parsing everything after the reasoning is plain content
+ if (!builder.syntax().parse_tool_calls) {
+ builder.add_content(builder.consume_rest());
+ return;
+ }
+
+ // Parse tool calls: each one is delimited by the begin/end regexes below
+ static const common_regex tool_call_begin_regex("");
+ static const common_regex tool_call_end_regex("");
+ static const common_regex function_regex("]+)>");
+ static const common_regex param_regex("]+)>");
+
+ while (auto tool_res = builder.try_find_regex(tool_call_begin_regex)) {
+ builder.consume_spaces(); // Consume whitespace after
+
+ // Look for function call inside tool call, ignore any content before it
+ if (auto func_res = builder.try_find_regex(function_regex, std::string::npos, false)) {
+ auto function_name = builder.str(func_res->groups[1]);
+
+ // Parse Seed-OSS parameters value
+ json args = json::object();
+ // Parse all parameters
+ while (auto param_res = builder.try_find_regex(param_regex, std::string::npos, false)) {
+ // again, ignore noise around parameters
+ auto param_name = builder.str(param_res->groups[1]);
+ builder.move_to(param_res->groups[0].end);
+ builder.consume_spaces(); // Consume whitespace after parameter
+ auto savedPos = builder.pos();
+ if (auto param_parse = builder.try_find_literal("")) {
+ // Keep the raw text as a fallback, then rewind and try to read the value as JSON
+ auto param = param_parse->prelude;
+ builder.move_to(savedPos);
+ try {
+ if (auto json_res = builder.try_consume_json()) {
+ args[param_name] = json_res->json;
+ } else {
+ args[param_name] = param;
+ }
+ } catch (json::exception &) {
+ args[param_name] = param;
+ }
+ } else {
+ throw common_chat_msg_partial_exception("Incomplete tool parameter");
+ }
+ }
+ // Look for closing function tag
+ auto end_func = builder.try_find_literal("");
+ if (end_func) {
+ builder.move_to(end_func->groups[0].end);
+ builder.consume_spaces(); // Consume whitespace after
+
+ // Add the tool call with parsed arguments, but only if we REALLY got the literal
+ // (substr takes (pos, count), so pass the match length rather than its end offset)
+ auto eaten_fragment = builder.input().substr(end_func->groups[0].begin, end_func->groups[0].end - end_func->groups[0].begin);
+ auto funlen = std::string("").length();
+ if (eaten_fragment.length() >= funlen && eaten_fragment.substr(0, funlen) == std::string("")) {
+ if (!builder.add_tool_call(function_name, "", args.dump())) {
+ throw common_chat_msg_partial_exception("Incomplete tool call");
+ }
+ } else {
+ throw common_chat_msg_partial_exception("Incomplete tool call");
+ }
+ } else {
+ throw common_chat_msg_partial_exception("Incomplete tool call");
+ }
+ // Look for closing tool call tag
+ if (auto end_tool = builder.try_find_regex(tool_call_end_regex, std::string::npos, false)) {
+ builder.move_to(end_tool->groups[0].end);
+ builder.consume_spaces(); // Consume trailing whitespace after tool call
+ } else {
+ throw common_chat_msg_partial_exception("Incomplete tool call");
+ }
+ } else {
+ // No function found - don't consume content here, let it be handled at the end
+ break;
+ }
+ }
+
+ // Consume any remaining whitespace after all tool call processing
+ builder.consume_spaces();
+ auto remaining = builder.consume_rest();
+ // If there's any non-whitespace content remaining, add it as content
+ if (!string_strip(remaining).empty()) {
+ builder.add_content(remaining);
+ }
+}
+
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.prompt = apply(tmpl, inputs);
@@ -2055,8 +2164,62 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha
return data;
}
+// Builds the chat params for a Seed-OSS template: renders the prompt, settles how the
+// thinking section is left at the end of the prompt, and - when tools are supplied -
+// builds a grammar with one rule per tool.
+// NOTE(review): the tag literals (and the element type of `tool_rules`) look stripped
+// from this listing - confirm them against the real file.
+static common_chat_params common_chat_params_init_seed_oss(
+ const common_chat_template & tmpl,
+ templates_params & params,
+ const common_chat_templates_inputs & inputs)
+{
+ common_chat_params data;
+ data.prompt = apply(tmpl, params);
+ data.format = COMMON_CHAT_FORMAT_SEED_OSS;
+ // When the prompt ends with the literal checked here: append the second literal if
+ // thinking is disabled, otherwise record that thinking was forced open.
+ if (string_ends_with(data.prompt, "")) {
+ if (!inputs.enable_thinking) {
+ data.prompt += "";
+ } else {
+ data.thinking_forced_open = true;
+ }
+ }
+
+ if (params.tools.is_array() && !params.tools.empty()) {
+ // The grammar is lazy unless the caller requires a tool call
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+ data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+ std::vector tool_rules;
+ foreach_function(params.tools, [&](const json & tool) {
+ const auto & function = tool.at("function");
+ std::string name = function.at("name");
+ auto parameters = function.at("parameters");
+ builder.resolve_refs(parameters);
+
+ // Create rule for Seed-OSS function call format:
+ // one schema-backed fragment per declared property, concatenated in order
+ std::string param_rules;
+ if (parameters.contains("properties")) {
+ for (const auto & [key, value] : parameters.at("properties").items()) {
+ param_rules += "\"\"" + builder.add_schema(name + "-arg-" + key, value) +
+ "\"\"";
+ }
+ }
+
+ tool_rules.push_back(builder.add_rule(name + "-call",
+ "\"\" space \"\" space " +
+ param_rules +
+ " \"\" space \"\""));
+ });
+
+ data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "" });
+
+ data.preserved_tokens = {
+ "", "", "", "",
+ "", "",
+ };
+
+ // The root rule accepts any one of the per-tool call rules
+ builder.add_rule("root", string_join(tool_rules, " | "));
+ });
+ }
+ return data;
+}
+
static common_chat_params common_chat_templates_apply_jinja(
- const struct common_chat_templates * tmpls,
+ const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs)
{
templates_params params;
@@ -2121,10 +2284,15 @@ static common_chat_params common_chat_templates_apply_jinja(
}
// GPT-OSS
- if (src.find("<|channel|>") != std::string::npos && params.json_schema.is_null()) {
+ if (src.find("<|channel|>") != std::string::npos) {
return common_chat_params_init_gpt_oss(tmpl, params);
}
+ // Seed-OSS
+ if (src.find("") != std::string::npos) {
+ return common_chat_params_init_seed_oss(tmpl, params, inputs);
+ }
+
// Use generic handler when mixing tools + JSON schema.
// TODO: support that mix in handlers below.
if ((params.tools.is_array() && params.json_schema.is_object())) {
@@ -2283,6 +2451,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_GPT_OSS:
common_chat_parse_gpt_oss(builder);
break;
+ case COMMON_CHAT_FORMAT_SEED_OSS:
+ common_chat_parse_seed_oss(builder);
+ break;
default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
}
diff --git a/common/chat.h b/common/chat.h
index d1e480c918..b09ff3b126 100644
--- a/common/chat.h
+++ b/common/chat.h
@@ -111,6 +111,7 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_GRANITE,
COMMON_CHAT_FORMAT_GPT_OSS,
+ COMMON_CHAT_FORMAT_SEED_OSS,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
diff --git a/common/common.cpp b/common/common.cpp
index decabcc2ed..054b43be77 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -988,7 +988,12 @@ struct common_init_result common_init_from_params(common_params & params) {
return iparams;
}
+ char buf[1024];
la.ptr = lora.get();
+ llama_adapter_meta_val_str(la.ptr, "adapter.lora.task_name", buf, sizeof(buf));
+ la.task_name = buf;
+ llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
+ la.prompt_prefix = buf;
iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
}
@@ -1152,7 +1157,6 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.yarn_orig_ctx = params.yarn_orig_ctx;
cparams.pooling_type = params.pooling_type;
cparams.attention_type = params.attention_type;
- cparams.defrag_thold = params.defrag_thold;
cparams.cb_eval = params.cb_eval;
cparams.cb_eval_user_data = params.cb_eval_user_data;
cparams.offload_kqv = !params.no_kv_offload;
diff --git a/common/common.h b/common/common.h
index 00c18179a0..3071e7b4ce 100644
--- a/common/common.h
+++ b/common/common.h
@@ -34,6 +34,9 @@ struct common_adapter_lora_info {
std::string path;
float scale;
+ std::string task_name;
+ std::string prompt_prefix;
+
struct llama_adapter_lora * ptr;
};
@@ -288,7 +291,6 @@ struct common_params {
float yarn_beta_fast = 32.0f; // YaRN low correction dim
float yarn_beta_slow = 1.0f; // YaRN high correction dim
int32_t yarn_orig_ctx = 0; // YaRN original context length
- float defrag_thold = 0.1f; // KV cache defragmentation threshold
// offload params
std::vector devices; // devices to use for offloading
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 42bf10d216..df37c4a6e4 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -72,6 +72,7 @@ class ModelBase:
endianess: gguf.GGUFEndian
use_temp_file: bool
lazy: bool
+ dry_run: bool
part_names: list[str]
is_safetensors: bool
hparams: dict[str, Any]
@@ -111,6 +112,7 @@ class ModelBase:
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
self.use_temp_file = use_temp_file
self.lazy = not eager or (remote_hf_model_id is not None)
+ self.dry_run = dry_run
self.remote_hf_model_id = remote_hf_model_id
if remote_hf_model_id is not None:
self.is_safetensors = True
@@ -1216,6 +1218,55 @@ class TextModel(ModelBase):
raise NotImplementedError("Only MEAN, CLS, and LAST pooling types supported")
self.gguf_writer.add_pooling_type(pooling_type)
+ def _set_vocab_interns1(self):
+ """Write the tokenizer vocabulary to the GGUF writer using the model's AutoTokenizer.
+
+ Ids missing from the vocab become "[PAD{i}]" entries typed UNUSED; added tokens are
+ typed CONTROL or USER_DEFINED; every other token is typed NORMAL. The tokenizer
+ model is written as "gpt2" and the BOS special token id is set explicitly.
+ """
+ tokens: list[str] = []
+ toktypes: list[int] = []
+
+ from transformers import AutoTokenizer
+ # NOTE(review): trust_remote_code=True lets tokenizer code shipped with the model run
+ # during conversion - only convert model directories you trust.
+ tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
+ vocab = getattr(tokenizer, 'vocab', tokenizer.get_vocab())
+ vocab_size = self.hparams.get("vocab_size", len(vocab))
+ assert max(vocab.values()) < vocab_size
+
+ tokpre = self.get_vocab_base_pre(tokenizer)
+
+ # id -> token text, so the vocab can be walked in id order
+ reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab.items()}
+ added_vocab = tokenizer.get_added_vocab()
+
+ added_tokens_decoder = tokenizer.added_tokens_decoder
+
+ for i in range(vocab_size):
+ if i not in reverse_vocab:
+ tokens.append(f"[PAD{i}]")
+ toktypes.append(gguf.TokenType.UNUSED)
+ else:
+ token: str = reverse_vocab[i]
+ if token in added_vocab:
+ # The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized.
+ # To avoid unexpected issues - we make sure to normalize non-normalized tokens
+ if not added_tokens_decoder[i].normalized:
+ previous_token = token
+ token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
+ if previous_token != token:
+ logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
+
+ if added_tokens_decoder[i].special or self.does_token_look_special(token):
+ toktypes.append(gguf.TokenType.CONTROL)
+ else:
+ toktypes.append(gguf.TokenType.USER_DEFINED)
+ else:
+ toktypes.append(gguf.TokenType.NORMAL)
+ tokens.append(token)
+
+ self.gguf_writer.add_tokenizer_model("gpt2")
+ self.gguf_writer.add_tokenizer_pre(tokpre)
+ self.gguf_writer.add_token_list(tokens)
+ self.gguf_writer.add_token_types(toktypes)
+
+ special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
+ # The BOS id is hard-coded here rather than taken from the tokenizer files
+ special_vocab._set_special_token("bos", 151643)
+ special_vocab.add_to_gguf(self.gguf_writer)
+
class MmprojModel(ModelBase):
model_type = ModelType.MMPROJ
@@ -2932,7 +2983,8 @@ class Qwen2Model(TextModel):
if "language_model." in name:
name = name.replace("language_model.", "") # for InternVL
if name.startswith("mlp") or name.startswith("multi_modal_projector") \
- or name.startswith("vision_model") or name.startswith("audio_tower"):
+ or name.startswith("vision_model") or name.startswith("audio_tower") \
+ or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"):
# skip vision and audio tensors
return []
yield from super().modify_tensors(data_torch, name, bid)
@@ -3109,7 +3161,7 @@ class LLaDAModel(TextModel):
yield from super().modify_tensors(data_torch, name, bid)
-@ModelBase.register("Ernie4_5_ForCausalLM")
+@ModelBase.register("Ernie4_5_ForCausalLM", "Ernie4_5ForCausalLM")
class Ernie4_5Model(TextModel):
model_arch = gguf.MODEL_ARCH.ERNIE4_5
@@ -3604,6 +3656,19 @@ class Qwen2MoeModel(TextModel):
class Qwen3Model(Qwen2Model):
model_arch = gguf.MODEL_ARCH.QWEN3
+ def __init__(self, *args, **kwargs):
+ """Initialise the base model, then record the first entry of the config's 'architectures' list."""
+ super().__init__(*args, **kwargs)
+ # Re-read the hparams from the model directory to get the 'architectures' entry
+ hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
+ # None when the config has no 'architectures' key
+ self.origin_hf_arch = hparams.get('architectures', [None])[0]
+
+ def set_vocab(self):
+ """Use the InternS1 vocab path for InternS1 checkpoints, otherwise the parent's vocab logic."""
+ # deal with intern-s1-mini
+ if self.origin_hf_arch == 'InternS1ForConditionalGeneration':
+ self._set_vocab_interns1()
+ return
+
+ super().set_vocab()
+
@ModelBase.register("Qwen3MoeForCausalLM")
class Qwen3MoeModel(Qwen2MoeModel):
@@ -3620,73 +3685,7 @@ class Qwen3MoeModel(Qwen2MoeModel):
self._set_vocab_interns1()
return
- try:
- self._set_vocab_sentencepiece()
- except FileNotFoundError:
- self._set_vocab_gpt2()
-
- def _set_vocab_interns1(self):
- tokens: list[str] = []
- toktypes: list[int] = []
-
- from transformers import AutoTokenizer
- tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
- vocab = getattr(tokenizer, 'vocab', tokenizer.get_vocab())
- vocab_size = self.hparams.get("vocab_size", len(vocab))
- assert max(vocab.values()) < vocab_size
-
- tokpre = self.get_vocab_base_pre(tokenizer)
-
- reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab.items()}
- added_vocab = tokenizer.get_added_vocab()
-
- added_tokens_decoder = tokenizer.added_tokens_decoder
-
- for i in range(vocab_size):
- if i not in reverse_vocab:
- tokens.append(f"[PAD{i}]")
- toktypes.append(gguf.TokenType.UNUSED)
- else:
- token: str = reverse_vocab[i]
- if token in added_vocab:
- # The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized.
- # To avoid unexpected issues - we make sure to normalize non-normalized tokens
- if not added_tokens_decoder[i].normalized:
- previous_token = token
- token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
- if previous_token != token:
- logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
-
- if added_tokens_decoder[i].special or self.does_token_look_special(token):
- toktypes.append(gguf.TokenType.CONTROL)
- else:
- toktypes.append(gguf.TokenType.USER_DEFINED)
- else:
- toktypes.append(gguf.TokenType.NORMAL)
- tokens.append(token)
-
- self.gguf_writer.add_tokenizer_model("gpt2")
- self.gguf_writer.add_tokenizer_pre(tokpre)
- self.gguf_writer.add_token_list(tokens)
- self.gguf_writer.add_token_types(toktypes)
-
- special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
- special_tokens_map_file = self.dir_model / 'special_tokens_map.json'
- additional_special_tokens = []
- if special_tokens_map_file.is_file():
- with open(special_tokens_map_file, encoding = 'utf-8') as f:
- additional_special_tokens = json.load(f).get('additional_special_tokens', [])
- tokenizer_cfg_file = self.dir_model / 'special_tokens_map.json'
- if tokenizer_cfg_file.is_file():
- with open(tokenizer_cfg_file, encoding = 'utf-8') as f:
- added_tokens_decoder = json.load(f).get('added_tokens_decoder', {})
- token2ids_map = {data['content'] : int(token) for token, data in added_tokens_decoder.items() if data['special']}
- for token in additional_special_tokens:
- if token in token2ids_map:
- special_vocab._set_special_token(token, token2ids_map[token])
- special_vocab._set_special_token('eos', 151645)
- special_vocab._set_special_token("bos", 151643)
- special_vocab.add_to_gguf(self.gguf_writer)
+ super().set_vocab()
@ModelBase.register("GPT2LMHeadModel")
@@ -4874,11 +4873,35 @@ class NeoBert(BertModel):
@ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
class XLMRobertaModel(BertModel):
model_arch = gguf.MODEL_ARCH.BERT
+ _lora_files = {}
+ _lora_names = []
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
+ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any):
+ hparams = kwargs.pop("hparams", None)
+ if hparams is None:
+ hparams = ModelBase.load_hparams(dir_model, False)
+
+ if lora_names := hparams.get("lora_adaptations"):
+ self._lora_names = lora_names
+ self.model_arch = gguf.MODEL_ARCH.JINA_BERT_V3
+
+ super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs)
self._xlmroberta_tokenizer_init()
+ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+ """Create one GGUF writer per LoRA adaptation name, then defer to the parent implementation."""
+ if self._lora_names:
+ for name in self._lora_names:
+ # Each adapter goes to its own file: the output filename prefixed with "lora-<name>-"
+ fname = self.add_prefix_to_filename(self.fname_out, f"lora-{name}-")
+ self._lora_files[name] = gguf.GGUFWriter(fname, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file, dry_run=self.dry_run)
+
+ return super().generate_extra_tensors()
+
+ def set_type(self):
+ """Tag every LoRA writer as an adapter of type "lora", then set the main model's type."""
+ for lora_writer in self._lora_files.values():
+ lora_writer.add_type(gguf.GGUFType.ADAPTER)
+ lora_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
+ super().set_type()
+
def set_vocab(self):
self._xlmroberta_set_vocab()
@@ -4888,13 +4911,62 @@ class XLMRobertaModel(BertModel):
if name.startswith("roberta."):
name = name[8:]
+ # jina-embeddings-v3
+ if ".parametrizations." in name:
+ name = name.replace(".parametrizations.", ".")
+ if name.endswith(".original"):
+ name = name[:-9]
+
# position embeddings start at pad_token_id + 1, so just chop down the weight tensor
if name == "embeddings.position_embeddings.weight":
if self._position_offset is not None:
data_torch = data_torch[self._position_offset:,:]
+ if name.endswith(".0.lora_A") or name.endswith(".0.lora_B"):
+ if name.startswith("pooler.dense"):
+ return []
+
+ num_loras = data_torch.size(0)
+ assert num_loras == len(self._lora_names)
+
+ # Split out each LoRA in their own GGUF
+ for i, lora_writer in enumerate(self._lora_files.values()):
+ new_name = self.map_tensor_name(name[:-9]) + name[-7:].lower()
+ data = data_torch[i, :, :]
+ # Transpose/flip token_embd/types into correct shape
+ if new_name == "token_embd.weight.lora_b":
+ data = data.T
+ elif new_name.startswith("token_types.weight."):
+ new_name = new_name[:-1] + ("a" if new_name[-1:] == "b" else "b")
+ lora_writer.add_tensor(new_name, data.float().numpy(), raw_dtype=gguf.GGMLQuantizationType.F32)
+
+ return []
+
return super().modify_tensors(data_torch, name, bid)
+ def set_gguf_parameters(self):
+ """Write the parent's parameters, then the rope base and per-adapter LoRA metadata if configured."""
+ super().set_gguf_parameters()
+
+ # jina-embeddings-v3
+ if rotary_emb_base := self.hparams.get("rotary_emb_base"):
+ self.gguf_writer.add_rope_freq_base(rotary_emb_base)
+ lora_alpha = self.hparams.get("lora_alpha")
+ if lora_prompt_prefixes := self.hparams.get("task_instructions"):
+ # When task instructions exist, every LoRA writer must have a matching entry
+ assert self._lora_files and all(lora_name in lora_prompt_prefixes for lora_name in self._lora_files.keys())
+ for lora_name, lora_writer in self._lora_files.items():
+ # Alpha falls back to 1.0 when the config does not provide "lora_alpha"
+ lora_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, lora_alpha if lora_alpha is not None else 1.0)
+ lora_writer.add_string(gguf.Keys.Adapter.LORA_TASK_NAME, lora_name)
+ if lora_prompt_prefixes:
+ lora_writer.add_string(gguf.Keys.Adapter.LORA_PROMPT_PREFIX, lora_prompt_prefixes[lora_name])
+
+ def write(self):
+ """Write the main model, then write out and close each LoRA adapter file."""
+ super().write()
+ for lora_writer in self._lora_files.values():
+ # Header, key/value data, then tensors - written in that order for each adapter
+ lora_writer.write_header_to_file()
+ lora_writer.write_kv_data_to_file()
+ lora_writer.write_tensors_to_file(progress=True)
+ lora_writer.close()
+
@ModelBase.register("GemmaForCausalLM")
class GemmaModel(TextModel):
@@ -5854,6 +5926,11 @@ class OlmoModel(TextModel):
return [(self.map_tensor_name(name), data_torch)]
+@ModelBase.register("SeedOssForCausalLM")
+class SeedOssModel(TextModel):
+ """Converter registered for "SeedOssForCausalLM"; only overrides the GGUF architecture."""
+ model_arch = gguf.MODEL_ARCH.SEED_OSS
+
+
@ModelBase.register("Olmo2ForCausalLM")
class Olmo2Model(TextModel):
model_arch = gguf.MODEL_ARCH.OLMO2
@@ -6252,9 +6329,11 @@ class DeepseekModel(TextModel):
raise ValueError(f"Unprocessed experts: {experts}")
-@ModelBase.register("DeepseekV2ForCausalLM")
-@ModelBase.register("DeepseekV3ForCausalLM")
-@ModelBase.register("KimiVLForConditionalGeneration")
+@ModelBase.register(
+ "DeepseekV2ForCausalLM",
+ "DeepseekV3ForCausalLM",
+ "KimiVLForConditionalGeneration",
+)
class DeepseekV2Model(TextModel):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
@@ -7467,9 +7546,13 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
]
# n_group and d_inner are used during reshape_tensors for mamba2
- self.d_model = self.find_hparam(["hidden_size", "d_model"])
- self.n_group = self.find_hparam(["n_groups"])
- self.d_inner = self.find_hparam(["expand"]) * self.d_model
+ # NOTE: Explicitly include the hparam prefix for d_model to
+ # disambiguate from the top-level head_dim
+ # NOTE 2: If needed for future models, this can be isolated in a method
+ # to separate the prefix setting and the keys used
+ self.d_model = self.find_hparam([f"{self.hparam_prefixes[0]}_head_dim", "hidden_size", "d_model"])
+ self.n_group = self.find_hparam(["n_groups", "num_groups"])
+ self.d_inner = self.find_hparam(["expand", "num_heads"]) * self.d_model
def get_attn_layers(self):
# Explicit list of layer type names
@@ -7530,12 +7613,12 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
## Mamba mixer params ##
self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
- self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
+ self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state", "state_dim", "ssm_state_size"]))
self.gguf_writer.add_ssm_group_count(self.n_group)
self.gguf_writer.add_ssm_inner_size(self.d_inner)
# NOTE: The mamba_dt_rank is _not_ the right field for how this is used
# in llama.cpp
- self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
+ self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads", "num_heads"]))
## Attention params ##
head_count_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
@@ -7562,6 +7645,55 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
Mamba2Model.set_vocab(self)
+@ModelBase.register("NemotronHForCausalLM")
+class NemotronHModel(GraniteHybridModel):
+ """Hybrid mamba2/attention model from NVIDIA"""
+ model_arch = gguf.MODEL_ARCH.NEMOTRON_H
+
+ def __init__(self, *args, **kwargs):
+ """Run the parent setup, then derive head_dim, d_inner and the SSM/MLP layer lists from hparams."""
+ super().__init__(*args, **kwargs)
+
+ # Save the top-level head_dim for later
+ self.head_dim = self.hparams.get("head_dim", self.hparams.get("attention_head_dim"))
+ assert self.head_dim is not None, "Could not find the attention head dim in config"
+
+ # Don't use expand to calculate d_inner
+ self.d_inner = self.find_hparam(["num_heads"]) * self.d_model
+
+ # Update the ssm / attn / mlp layers
+ # M: Mamba2, *: Attention, -: MLP
+ hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
+ self._ssm_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "M"]
+ self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "-"]
+
+ def get_attn_layers(self):
+ """Return the indices marked "*" in hybrid_override_pattern (one character per block)."""
+ hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
+ assert len(hybrid_override_pattern) == self.block_count, "Mismatch between hybrid override and num_hidden_layers!"
+ return [i for i, val in enumerate(hybrid_override_pattern) if val == "*"]
+
+ def set_gguf_parameters(self):
+ """Write the parent's parameters, then the key/value lengths and a per-layer feed-forward length."""
+ super().set_gguf_parameters()
+
+ self.gguf_writer.add_key_length(self.head_dim)
+ self.gguf_writer.add_value_length(self.head_dim)
+
+ # Set feed_forward_length
+ # NOTE: This will trigger an override warning. This is preferable to
+ # duplicating all the parent logic
+ n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"])
+ # Only the MLP layers get n_ff; every other layer is written as 0
+ self.gguf_writer.add_feed_forward_length([
+ n_ff if i in self._mlp_layers else 0 for i in range(self.block_count)
+ ])
+
+ def set_vocab(self):
+ """Write the parent's vocab, then force the add-BOS flag on."""
+ super().set_vocab()
+
+ # The tokenizer _does_ add a BOS token (via post_processor type
+ # TemplateProcessing) but does not set add_bos_token to true in the
+ # config, so we need to explicitly override it here.
+ self.gguf_writer.add_add_bos_token(True)
+
+
@ModelBase.register("BailingMoeForCausalLM")
class BailingMoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.BAILINGMOE
@@ -8505,6 +8637,43 @@ class PixtralModel(LlavaVisionModel):
return "mm.2.weight"
return super().map_tensor_name(name, try_suffixes)
+
+@ModelBase.register("KimiVLForConditionalGeneration")
+class KimiVLModel(MmprojModel):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ assert self.hparams_vision is not None
+ self.hparams_vision["image_size"] = 64 * 14 # for compatibility
+
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+ self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.KIMIVL)
+ self.gguf_writer.add_vision_use_gelu(True)
+ self.gguf_writer.add_vision_projector_scale_factor(2)
+ # eps is the same as pytorch's default value
+ assert self.hparams_vision is not None
+ self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("layer_norm_eps", 1e-5))
+
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ del bid # unused
+ is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name
+
+ if is_vision_tensor:
+ if "pos_emb.weight" in name:
+ data_torch = data_torch.view(data_torch.shape[0] * data_torch.shape[1], data_torch.shape[2])
+ elif "wqkv" in name:
+ split_dim = 0 if "weight" in name else -1
+ wq, wk, wv = data_torch.chunk(3, dim=split_dim)
+ return [
+ (self.map_tensor_name(name.replace("wqkv", "wq")), wq),
+ (self.map_tensor_name(name.replace("wqkv", "wk")), wk),
+ (self.map_tensor_name(name.replace("wqkv", "wv")), wv)
+ ]
+
+ return [(self.map_tensor_name(name), data_torch)]
+
+ return [] # skip other tensors
+
###### CONVERSION LOGIC ######
diff --git a/docs/build-s390x.md b/docs/build-s390x.md
index b36a199814..f3cdd63be3 100644
--- a/docs/build-s390x.md
+++ b/docs/build-s390x.md
@@ -265,8 +265,9 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
| BF16 | 🚫 | 🚫 | ❓ | ❓ |
| Q4_0 | ✅ | ✅ | ❓ | ❓ |
| Q4_1 | ✅ | ✅ | ❓ | ❓ |
-| Q5_0 | 🚫 | 🚫 | ❓ | ❓ |
-| Q5_1 | 🚫 | 🚫 | ❓ | ❓ |
+| MXFP4 | 🚫 | 🚫 | ❓ | ❓ |
+| Q5_0 | ✅ | ✅ | ❓ | ❓ |
+| Q5_1 | ✅ | ✅ | ❓ | ❓ |
| Q8_0 | ✅ | ✅ | ❓ | ❓ |
| Q2_K | 🚫 | 🚫 | ❓ | ❓ |
| Q3_K | ✅ | ✅ | ❓ | ❓ |
@@ -291,4 +292,4 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
- 🚫 - acceleration unavailable, will still run using scalar implementation
- ❓ - acceleration unknown, please contribute if you can test it yourself
-Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on July 31, 2025.
+Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Aug 22, 2025.
diff --git a/docs/function-calling.md b/docs/function-calling.md
index 37eacaf310..67cf785c7a 100644
--- a/docs/function-calling.md
+++ b/docs/function-calling.md
@@ -21,6 +21,8 @@ Function calling is supported for all models (see https://github.com/ggml-org/ll
- Use `--chat-template-file` to override the template when appropriate (see examples below)
- Generic support may consume more tokens and be less efficient than a model's native format.
+- Multiple/parallel tool calling is supported on some models but is disabled by default. Enable it by passing `"parallel_tool_calls": true` in the completion endpoint payload.
+
Show some common templates and which format handler they use
diff --git a/docs/multimodal/minicpmv4.0.md b/docs/multimodal/minicpmv4.0.md
index 65887d9601..d04cb338ce 100644
--- a/docs/multimodal/minicpmv4.0.md
+++ b/docs/multimodal/minicpmv4.0.md
@@ -6,7 +6,7 @@ Download [MiniCPM-V-4](https://huggingface.co/openbmb/MiniCPM-V-4) PyTorch model
### Build llama.cpp
-Readme modification time: 20250206
+Readme modification time: 20250731
If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md)
diff --git a/docs/multimodal/minicpmv4.5.md b/docs/multimodal/minicpmv4.5.md
new file mode 100644
index 0000000000..8fea5e611d
--- /dev/null
+++ b/docs/multimodal/minicpmv4.5.md
@@ -0,0 +1,47 @@
+## MiniCPM-V 4.5
+
+### Prepare models and code
+
+Download the [MiniCPM-V-4_5](https://huggingface.co/openbmb/MiniCPM-V-4_5) PyTorch model from Hugging Face into a "MiniCPM-V-4_5" folder.
+
+
+### Build llama.cpp
+Readme modification time: 20250826
+
+If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md)
+
+Clone llama.cpp:
+```bash
+git clone https://github.com/ggerganov/llama.cpp
+cd llama.cpp
+```
+
+Build llama.cpp using `CMake`:
+```bash
+cmake -B build
+cmake --build build --config Release
+```
+
+
+### Usage of MiniCPM-V 4.5
+
+Convert the PyTorch model to gguf files (you can also download the [gguf](https://huggingface.co/openbmb/MiniCPM-V-4_5-gguf) files that we have already converted):
+
+```bash
+python ./tools/mtmd/legacy-models/minicpmv-surgery.py -m ../MiniCPM-V-4_5
+python ./tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-V-4_5 --minicpmv-projector ../MiniCPM-V-4_5/minicpmv.projector --output-dir ../MiniCPM-V-4_5/ --minicpmv_version 6
+python ./convert_hf_to_gguf.py ../MiniCPM-V-4_5/model
+
+# quantize int4 version
+./build/bin/llama-quantize ../MiniCPM-V-4_5/model/ggml-model-f16.gguf ../MiniCPM-V-4_5/model/ggml-model-Q4_K_M.gguf Q4_K_M
+```
+
+
+Inference on Linux or Mac
+```bash
+# run in single-turn mode
+./build/bin/llama-mtmd-cli -m ../MiniCPM-V-4_5/model/ggml-model-f16.gguf --mmproj ../MiniCPM-V-4_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
+
+# run in conversation mode
+./build/bin/llama-mtmd-cli -m ../MiniCPM-V-4_5/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-4_5/mmproj-model-f16.gguf
+```
diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp
index 61eefc7248..d4ef751fbb 100644
--- a/examples/eval-callback/eval-callback.cpp
+++ b/examples/eval-callback/eval-callback.cpp
@@ -28,9 +28,40 @@ static std::string ggml_ne_string(const ggml_tensor * t) {
return str;
}
+// Fetch the element at index (i0, i1, i2, i3) from a raw tensor buffer as a float.
+// nb[] is used as the per-dimension byte stride into data. F16, F32 and
+// I64/I32/I16/I8 elements are converted; any other type aborts.
+static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) {
+    const size_t offset = i0 * nb[0] + i1 * nb[1] + i2 * nb[2] + i3 * nb[3];
+    uint8_t * elem = data + offset;
+    float value;
+    // reinterpret the bytes at that offset according to the element type
+    switch (type) {
+        case GGML_TYPE_F16: value = ggml_fp16_to_fp32(*(ggml_fp16_t *) elem); break;
+        case GGML_TYPE_F32: value = *(float *) elem; break;
+        case GGML_TYPE_I64: value = (float) *(int64_t *) elem; break;
+        case GGML_TYPE_I32: value = (float) *(int32_t *) elem; break;
+        case GGML_TYPE_I16: value = (float) *(int16_t *) elem; break;
+        case GGML_TYPE_I8:  value = (float) *(int8_t *) elem; break;
+        default:
+            GGML_ABORT("fatal error");
+    }
+    return value;
+}
+
static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) {
GGML_ASSERT(n > 0);
float sum = 0;
+ for (int64_t i3 = 0; i3 < ne[3]; i3++) {
+ for (int64_t i2 = 0; i2 < ne[2]; i2++) {
+ for (int64_t i1 = 0; i1 < ne[1]; i1++) {
+ for (int64_t i0 = 0; i0 < ne[0]; i0++) {
+ const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3);
+ sum += v;
+ }
+ }
+ }
+ }
for (int64_t i3 = 0; i3 < ne[3]; i3++) {
LOG(" [\n");
for (int64_t i2 = 0; i2 < ne[2]; i2++) {
@@ -50,25 +81,8 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne
LOG("..., ");
i0 = ne[0] - n;
}
- size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
- float v;
- if (type == GGML_TYPE_F16) {
- v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
- } else if (type == GGML_TYPE_F32) {
- v = *(float *) &data[i];
- } else if (type == GGML_TYPE_I64) {
- v = (float) *(int64_t *) &data[i];
- } else if (type == GGML_TYPE_I32) {
- v = (float) *(int32_t *) &data[i];
- } else if (type == GGML_TYPE_I16) {
- v = (float) *(int16_t *) &data[i];
- } else if (type == GGML_TYPE_I8) {
- v = (float) *(int8_t *) &data[i];
- } else {
- GGML_ABORT("fatal error");
- }
+ const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3);
LOG("%12.4f", v);
- sum += v;
if (i0 < ne[0] - 1) LOG(", ");
}
LOG("],\n");
diff --git a/examples/llama.vim b/examples/llama.vim
index af3fd3935d..736802d365 100644
--- a/examples/llama.vim
+++ b/examples/llama.vim
@@ -17,7 +17,7 @@
"
" start the llama.cpp server with a FIM-compatible model. for example:
"
-" $ llama-server -m {model.gguf} --port 8012 -ngl 99 -fa -dt 0.1 --ubatch-size 512 --batch-size 1024 --cache-reuse 256
+" $ llama-server -m {model.gguf} --port 8012 -ngl 99 -fa --ubatch-size 512 --batch-size 1024 --cache-reuse 256
"
" --batch-size [512, model max context]
"
diff --git a/examples/model-conversion/Makefile b/examples/model-conversion/Makefile
index 27d95b4f2b..03b928afda 100644
--- a/examples/model-conversion/Makefile
+++ b/examples/model-conversion/Makefile
@@ -1,4 +1,5 @@
-# Validation functions
+MAKEFLAGS += --no-print-directory
+
define validate_model_path
@if [ -z "$(MODEL_PATH)" ]; then \
echo "Error: MODEL_PATH must be provided either as:"; \
@@ -17,6 +18,13 @@ define validate_embedding_model_path
fi
endef
+define quantize_model
+ @CONVERTED_MODEL="$(1)" QUANTIZED_TYPE="$(QUANTIZED_TYPE)" \
+ TOKEN_EMBD_TYPE="$(TOKEN_EMBD_TYPE)" OUTPUT_TYPE="$(OUTPUT_TYPE)" \
+ ./scripts/utils/quantize.sh "$(1)" "$(QUANTIZED_TYPE)" "$(TOKEN_EMBD_TYPE)" "$(OUTPUT_TYPE)"
+ @echo "Export the quantized model path to $(2) variable in your environment"
+endef
+
###
### Casual Model targets/recipes
###
@@ -29,6 +37,20 @@ causal-convert-model:
METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \
./scripts/causal/convert-model.sh
+causal-convert-mm-model-bf16: OUTTYPE=bf16
+causal-convert-mm-model-bf16: MM_OUTTYPE=f16
+causal-convert-mm-model-bf16: causal-convert-mm-model
+
+causal-convert-mm-model:
+ $(call validate_model_path,causal-convert-mm-model)
+ @MODEL_NAME="$(MODEL_NAME)" OUTTYPE="$(OUTTYPE)" MODEL_PATH="$(MODEL_PATH)" \
+ METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \
+ ./scripts/causal/convert-model.sh
+
+ @MODEL_NAME="$(MODEL_NAME)" OUTTYPE="$(MM_OUTTYPE)" MODEL_PATH="$(MODEL_PATH)" \
+ METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \
+ ./scripts/causal/convert-model.sh --mmproj
+
causal-run-original-model:
$(call validate_model_path,causal-run-original-model)
@MODEL_PATH="$(MODEL_PATH)" ./scripts/causal/run-org-model.py
@@ -67,9 +89,15 @@ causal-quantize-Q8_0: causal-quantize-model
causal-quantize-Q4_0: QUANTIZED_TYPE = Q4_0
causal-quantize-Q4_0: causal-quantize-model
+# For Quantization Aware Trained (QAT) models in Q4_0 we explicitly set the
+# token embedding and output types to Q8_0 instead of the default Q6_K.
+causal-quantize-qat-Q4_0: QUANTIZED_TYPE = Q4_0
+causal-quantize-qat-Q4_0: TOKEN_EMBD_TYPE = Q8_0
+causal-quantize-qat-Q4_0: OUTPUT_TYPE = Q8_0
+causal-quantize-qat-Q4_0: causal-quantize-model
+
causal-quantize-model:
- @CONVERTED_MODEL="$(CONVERTED_MODEL)" QUANTIZED_TYPE="$(QUANTIZED_TYPE)" ./scripts/utils/quantize.sh ${CONVERTED_MODEL} ${QUANTIZED_TYPE}
- @echo "Export the quantized model path to QUANTIZED_MODEL variable in your environment"
+ $(call quantize_model,$(CONVERTED_MODEL),QUANTIZED_MODEL)
causal-run-quantized-model:
@QUANTIZED_MODEL="$(QUANTIZED_MODEL)" ./scripts/causal/run-converted-model.sh ${QUANTIZED_MODEL}
@@ -117,9 +145,15 @@ embedding-quantize-Q8_0: embedding-quantize-model
embedding-quantize-Q4_0: QUANTIZED_TYPE = Q4_0
embedding-quantize-Q4_0: embedding-quantize-model
+# For Quantization Aware Trained (QAT) models in Q4_0 we explicitly set the
+# token embedding and output types to Q8_0 instead of the default Q6_K.
+embedding-quantize-qat-Q4_0: QUANTIZED_TYPE = Q4_0
+embedding-quantize-qat-Q4_0: TOKEN_EMBD_TYPE = Q8_0
+embedding-quantize-qat-Q4_0: OUTPUT_TYPE = Q8_0
+embedding-quantize-qat-Q4_0: embedding-quantize-model
+
embedding-quantize-model:
- @./scripts/utils/quantize.sh ${CONVERTED_EMBEDDING_MODEL} ${QUANTIZED_TYPE}
- @echo "Export the quantized model path to QUANTIZED_EMBEDDING_MODEL variable in your environment"
+ $(call quantize_model,$(CONVERTED_EMBEDDING_MODEL),QUANTIZED_EMBEDDING_MODEL)
embedding-run-quantized-model:
@./scripts/embedding/run-converted-model.sh ${QUANTIZED_EMBEDDING_MODEL}
@@ -144,6 +178,15 @@ perplexity-run:
hf-create-model:
@./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}"
+hf-create-model-dry-run:
+ @./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -d
+
+hf-create-model-embedding:
+ @./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -e
+
+hf-create-model-embedding-dry-run:
+ @./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -e -d
+
hf-create-model-private:
@./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -p
diff --git a/examples/model-conversion/README.md b/examples/model-conversion/README.md
index c924a6be3c..5e5992d964 100644
--- a/examples/model-conversion/README.md
+++ b/examples/model-conversion/README.md
@@ -137,6 +137,18 @@ Then the quantized model can be run using the following command:
(venv) $ make causal-run-quantized-model
```
+### Quantizing QAT (Quantization Aware Training) models
+When quantizing to `Q4_0`, the default data type for the token embedding weights
+will be `Q6_K`. For models that are going to be uploaded to ggml-org it is
+recommended to use `Q8_0` instead for the embeddings and output tensors.
+The reason is that although `Q6_K` is smaller in size, it requires more compute
+to unpack, which can hurt performance during output generation when the entire
+embedding matrix must be dequantized to compute vocabulary logits. `Q8_0`
+provides practically full quality with better computational efficiency.
+```console
+(venv) $ make causal-quantize-qat-Q4_0
+```
+
## Embedding Language Model Conversion
@@ -238,6 +250,18 @@ Then the quantized model can be run using the following command:
(venv) $ make embedding-run-quantized-model
```
+### Quantizing QAT (Quantization Aware Training) models
+When quantizing to `Q4_0`, the default data type for the token embedding weights
+will be `Q6_K`. For models that are going to be uploaded to ggml-org it is
+recommended to use `Q8_0` instead for the embeddings and output tensors.
+The reason is that although `Q6_K` is smaller in size, it requires more compute
+to unpack, which can hurt performance during output generation when the entire
+embedding matrix must be dequantized to compute vocabulary logits. `Q8_0`
+provides practically full quality with better computational efficiency.
+```console
+(venv) $ make embedding-quantize-qat-Q4_0
+```
+
## Perplexity Evaluation
### Simple perplexity evaluation
@@ -285,13 +309,21 @@ For the following targets a `HF_TOKEN` environment variable is required.
This will create a new model repsository on Hugging Face with the specified
model name.
```console
-(venv) $ make hf-create-model MODEL_NAME='TestModel' NAMESPACE="danbev"
+(venv) $ make hf-create-model MODEL_NAME='TestModel' NAMESPACE="danbev" ORIGINAL_BASE_MODEL="some-base-model"
Repository ID: danbev/TestModel-GGUF
Repository created: https://huggingface.co/danbev/TestModel-GGUF
```
Note that we append a `-GGUF` suffix to the model name to ensure a consistent
naming convention for GGUF models.
+An embedding model can be created using the following command:
+```console
+(venv) $ make hf-create-model-embedding MODEL_NAME='TestEmbeddingModel' NAMESPACE="danbev" ORIGINAL_BASE_MODEL="some-base-model"
+```
+The only difference is that the model card for an embedding model differs
+with regard to the llama-server command and how to access/call the embedding
+endpoint.
+
### Upload a GGUF model to model repository
The following target uploads a model to an existing Hugging Face model repository.
```console
diff --git a/examples/model-conversion/logits.cpp b/examples/model-conversion/logits.cpp
index 2cac6a3b3e..ddc5e9005f 100644
--- a/examples/model-conversion/logits.cpp
+++ b/examples/model-conversion/logits.cpp
@@ -112,6 +112,7 @@ int main(int argc, char ** argv) {
ctx_params.no_perf = false;
if (embedding_mode) {
ctx_params.embeddings = true;
+ ctx_params.pooling_type = LLAMA_POOLING_TYPE_NONE;
ctx_params.n_ubatch = ctx_params.n_batch;
}
diff --git a/examples/model-conversion/scripts/causal/convert-model.sh b/examples/model-conversion/scripts/causal/convert-model.sh
index 56b21f9baa..9d95025950 100755
--- a/examples/model-conversion/scripts/causal/convert-model.sh
+++ b/examples/model-conversion/scripts/causal/convert-model.sh
@@ -1,5 +1,21 @@
#!/bin/bash
+set -e
+
+# Parse command line arguments
+MMPROJ=""
+while [[ $# -gt 0 ]]; do
+ case $1 in
+ --mmproj)
+ MMPROJ="--mmproj"
+ shift
+ ;;
+ *)
+ shift
+ ;;
+ esac
+done
+
MODEL_NAME="${MODEL_NAME:-$(basename "$MODEL_PATH")}"
OUTPUT_DIR="${OUTPUT_DIR:-../../models}"
TYPE="${OUTTYPE:-f16}"
@@ -11,12 +27,20 @@ echo "Model name: ${MODEL_NAME}"
echo "Data type: ${TYPE}"
echo "Converted model path:: ${CONVERTED_MODEL}"
echo "Metadata override: ${METADATA_OVERRIDE}"
-python ../../convert_hf_to_gguf.py --verbose \
- ${MODEL_PATH} \
- --outfile ${CONVERTED_MODEL} \
- --outtype ${TYPE} \
- --metadata "${METADATA_OVERRIDE}"
+
+CMD_ARGS=("python" "../../convert_hf_to_gguf.py" "--verbose")
+CMD_ARGS+=("${MODEL_PATH}")
+CMD_ARGS+=("--outfile" "${CONVERTED_MODEL}")
+CMD_ARGS+=("--outtype" "${TYPE}")
+[[ -n "$METADATA_OVERRIDE" ]] && CMD_ARGS+=("--metadata" "${METADATA_OVERRIDE}")
+[[ -n "$MMPROJ" ]] && CMD_ARGS+=("${MMPROJ}")
+
+"${CMD_ARGS[@]}"
echo ""
echo "The environment variable CONVERTED_MODEL can be set to this path using:"
echo "export CONVERTED_MODEL=$(realpath ${CONVERTED_MODEL})"
+if [[ -n "$MMPROJ" ]]; then
+ mmproj_file="${OUTPUT_DIR}/mmproj-$(basename "${CONVERTED_MODEL}")"
+ echo "The mmproj model was created in $(realpath "$mmproj_file")"
+fi
diff --git a/examples/model-conversion/scripts/readme.md.template b/examples/model-conversion/scripts/causal/modelcard.template
similarity index 100%
rename from examples/model-conversion/scripts/readme.md.template
rename to examples/model-conversion/scripts/causal/modelcard.template
diff --git a/examples/model-conversion/scripts/embedding/modelcard.template b/examples/model-conversion/scripts/embedding/modelcard.template
new file mode 100644
index 0000000000..75c580524f
--- /dev/null
+++ b/examples/model-conversion/scripts/embedding/modelcard.template
@@ -0,0 +1,48 @@
+---
+base_model:
+- {base_model}
+---
+# {model_name} GGUF
+
+Recommended way to run this model:
+
+```sh
+llama-server -hf {namespace}/{model_name}-GGUF
+```
+
+Then the endpoint can be accessed at http://localhost:8080/embedding, for
+example using `curl`:
+```console
+curl --request POST \
+ --url http://localhost:8080/embedding \
+ --header "Content-Type: application/json" \
+ --data '{{"input": "Hello embeddings"}}' \
+ --silent
+```
+
+Alternatively, the `llama-embedding` command line tool can be used:
+```sh
+llama-embedding -hf {namespace}/{model_name}-GGUF --verbose-prompt -p "Hello embeddings"
+```
+
+#### embd_normalize
+When a model uses pooling, or the pooling method is specified using `--pooling`,
+the normalization can be controlled by the `embd_normalize` parameter.
+
+The default value is `2` which means that the embeddings are normalized using
+the Euclidean norm (L2). Other options are:
+* -1 No normalization
+* 0 Max absolute
+* 1 Taxicab
+* 2 Euclidean/L2
+* \>2 P-Norm
+
+This can be passed in the request body to `llama-server`, for example:
+```sh
+ --data '{{"input": "Hello embeddings", "embd_normalize": -1}}' \
+```
+
+And for `llama-embedding`, by passing `--embd-normalize <value>`, for example:
+```sh
+llama-embedding -hf {namespace}/{model_name}-GGUF --embd-normalize -1 -p "Hello embeddings"
+```
diff --git a/examples/model-conversion/scripts/utils/hf-create-model.py b/examples/model-conversion/scripts/utils/hf-create-model.py
index 09bb8511ef..ea99bd886f 100755
--- a/examples/model-conversion/scripts/utils/hf-create-model.py
+++ b/examples/model-conversion/scripts/utils/hf-create-model.py
@@ -26,21 +26,31 @@ parser.add_argument('--namespace', '-ns', help='Namespace to add the model to',
parser.add_argument('--org-base-model', '-b', help='Original Base model name', default="")
parser.add_argument('--no-card', action='store_true', help='Skip creating model card')
parser.add_argument('--private', '-p', action='store_true', help='Create private model')
+parser.add_argument('--embedding', '-e', action='store_true', help='Use embedding model card template')
+parser.add_argument('--dry-run', '-d', action='store_true', help='Print repository info and template without creating repository')
args = parser.parse_args()
repo_id = f"{args.namespace}/{args.model_name}-GGUF"
print("Repository ID: ", repo_id)
-repo_url = api.create_repo(
- repo_id=repo_id,
- repo_type="model",
- private=args.private,
- exist_ok=False
-)
+repo_url = None
+if not args.dry_run:
+ repo_url = api.create_repo(
+ repo_id=repo_id,
+ repo_type="model",
+ private=args.private,
+ exist_ok=False
+ )
if not args.no_card:
- template_path = "scripts/readme.md.template"
+ if args.embedding:
+ template_path = "scripts/embedding/modelcard.template"
+ else:
+ template_path = "scripts/causal/modelcard.template"
+
+ print("Template path: ", template_path)
+
model_card_content = load_template_and_substitute(
template_path,
model_name=args.model_name,
@@ -48,16 +58,21 @@ if not args.no_card:
base_model=args.org_base_model,
)
- if model_card_content:
- api.upload_file(
- path_or_fileobj=model_card_content.encode('utf-8'),
- path_in_repo="README.md",
- repo_id=repo_id
- )
- print("Model card created successfully.")
+ if args.dry_run:
+ print("\nTemplate Content:\n")
+ print(model_card_content)
else:
- print("Failed to create model card.")
+ if model_card_content:
+ api.upload_file(
+ path_or_fileobj=model_card_content.encode('utf-8'),
+ path_in_repo="README.md",
+ repo_id=repo_id
+ )
+ print("Model card created successfully.")
+ else:
+ print("Failed to create model card.")
-print(f"Repository created: {repo_url}")
+if not args.dry_run and repo_url:
+ print(f"Repository created: {repo_url}")
diff --git a/examples/model-conversion/scripts/utils/quantize.sh b/examples/model-conversion/scripts/utils/quantize.sh
index bcb8775754..90460aa6b0 100755
--- a/examples/model-conversion/scripts/utils/quantize.sh
+++ b/examples/model-conversion/scripts/utils/quantize.sh
@@ -4,6 +4,8 @@ set -e
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
QUANTIZED_TYPE="${2:-"$QUANTIZED_TYPE"}"
+TOKEN_EMBD_TYPE="${3:-"${TOKEN_EMBD_TYPE}"}"
+OUTPUT_TYPE="${4:-"${OUTPUT_TYPE}"}"
QUANTIZED_MODEL=$CONVERTED_MODEL
# Final check if we have a model path
@@ -14,6 +16,11 @@ if [ -z "$CONVERTED_MODEL" ]; then
exit 1
fi
+if [ -z "$QUANTIZED_TYPE" ]; then
+ echo "Error: QUANTIZED_TYPE is required" >&2
+ exit 1
+fi
+
echo $CONVERTED_MODEL
# Process the quantized model filename
@@ -26,9 +33,16 @@ else
exit 1
fi
-
cmake --build ../../build --target llama-quantize -j8
-../../build/bin/llama-quantize $CONVERTED_MODEL $QUANTIZED_MODEL $QUANTIZED_TYPE
+echo $TOKEN_EMBD_TYPE
+echo $OUTPUT_TYPE
+
+CMD_ARGS=("../../build/bin/llama-quantize")
+[[ -n "$TOKEN_EMBD_TYPE" ]] && CMD_ARGS+=("--token-embedding-type" "$TOKEN_EMBD_TYPE")
+[[ -n "$OUTPUT_TYPE" ]] && CMD_ARGS+=("--output-tensor-type" "$OUTPUT_TYPE")
+CMD_ARGS+=("$CONVERTED_MODEL" "$QUANTIZED_MODEL" "$QUANTIZED_TYPE")
+
+"${CMD_ARGS[@]}"
echo "Quantized model saved to: $QUANTIZED_MODEL"
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index b8b82e11c8..7e9c3c8c7a 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -512,6 +512,7 @@ extern "C" {
GGML_OP_IM2COL,
GGML_OP_IM2COL_BACK,
GGML_OP_CONV_2D,
+ GGML_OP_CONV_3D,
GGML_OP_CONV_2D_DW,
GGML_OP_CONV_TRANSPOSE_2D,
GGML_OP_POOL_1D,
@@ -1940,6 +1941,23 @@ extern "C" {
int d0, // dilation dimension 0
int d1); // dilation dimension 1
+ GGML_API struct ggml_tensor * ggml_conv_3d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a, // kernel [KW, KH, KD, IC * OC]
+ struct ggml_tensor * b, // input [W, H, D, C * N]
+ int s0, // stride
+ int s1,
+ int s2,
+ int p0, // padding
+ int p1,
+ int p2,
+ int d0, // dilation
+ int d1,
+ int d2,
+ int n_channels, // NOTE(review): presumably C (= IC) from the shapes above - confirm
+ int n_batch, // NOTE(review): presumably N from the input shape above - confirm
+ int n_channels_out); // NOTE(review): presumably OC from the kernel shape above - confirm
+
enum ggml_op_pool {
GGML_OP_POOL_MAX,
GGML_OP_POOL_AVG,
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
index c1e58fbb64..e34feccc98 100644
--- a/ggml/src/ggml-backend.cpp
+++ b/ggml/src/ggml-backend.cpp
@@ -1355,15 +1355,15 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
std::vector ids;
std::vector used_ids;
- for (int i = 0; i < sched->n_splits; i++) {
- struct ggml_backend_sched_split * split = &splits[i];
+ for (int split_id = 0; split_id < sched->n_splits; split_id++) {
+ struct ggml_backend_sched_split * split = &splits[split_id];
int split_backend_id = split->backend_id;
ggml_backend_t split_backend = sched->backends[split_backend_id];
// copy the input tensors to the split backend
- for (int j = 0; j < split->n_inputs; j++) {
- ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[j]);
- struct ggml_tensor * input = split->inputs[j];
+ for (int input_id = 0; input_id < split->n_inputs; input_id++) {
+ ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[input_id]);
+ struct ggml_tensor * input = split->inputs[input_id];
struct ggml_tensor * input_cpy = tensor_copy(input, split_backend_id, sched->cur_copy);
if (input->flags & GGML_TENSOR_FLAG_INPUT) {
@@ -1398,10 +1398,22 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
// get the ids
ggml_tensor * ids_tensor = node->src[2];
+ ggml_backend_t ids_backend = split_backend;
+
+ // if the ids tensor is also an input of the split, it may not have been copied yet to the split backend
+ // in that case, we use the original ids tensor
+ for (int i = input_id + 1; i < split->n_inputs; i++) {
+ if (ids_tensor == tensor_copy(split->inputs[i], split_backend_id, sched->cur_copy)) {
+ ids_tensor = split->inputs[i];
+ ids_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[i]);
+ break;
+ }
+ }
+
if (ids_tensor != prev_ids_tensor) {
ids.resize(ggml_nbytes(ids_tensor) / sizeof(int32_t));
- ggml_backend_tensor_get_async(split_backend, ids_tensor, ids.data(), 0, ggml_nbytes(ids_tensor));
- ggml_backend_synchronize(split_backend);
+ ggml_backend_tensor_get_async(ids_backend, ids_tensor, ids.data(), 0, ggml_nbytes(ids_tensor));
+ ggml_backend_synchronize(ids_backend);
// find the used experts
used_ids.clear();
@@ -1409,6 +1421,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
for (int64_t i1 = 0; i1 < ids_tensor->ne[1]; i1++) {
for (int64_t i0 = 0; i0 < ids_tensor->ne[0]; i0++) {
int32_t id = ids[i1 * ids_tensor->nb[1]/sizeof(int32_t) + i0 * ids_tensor->nb[0]/sizeof(int32_t)];
+ GGML_ASSERT(id >= 0 && id < n_expert);
ggml_bitset_set(used_ids.data(), id);
}
}
diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp
index 2a5cb8abfa..c42871c575 100755
--- a/ggml/src/ggml-cann/aclnn_ops.cpp
+++ b/ggml/src/ggml-cann/aclnn_ops.cpp
@@ -867,6 +867,86 @@ static aclTensor* aclnn_values(ggml_backend_cann_context& ctx, void* buffer,
return acl_tensor;
}
+/**
+ * @brief Set every element of a tensor to one scalar value.
+ *
+ * Wraps the value in a temporary ACL scalar, runs the in-place fill
+ * operator on `acl_dst`, then hands the temporary back for release.
+ *
+ * @param ctx     The context for the CANN backend operations.
+ * @param scalar  The value written into the tensor.
+ * @param acl_dst The tensor that is overwritten in place.
+ */
+static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
+                              aclTensor* acl_dst) {
+    aclScalar* fill_value = aclCreateScalar(&scalar, aclDataType::ACL_FLOAT);
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst, fill_value);
+    ggml_cann_release_resources(ctx, fill_value);
+}
+
+/**
+ * @brief Get or expand a cached float32 tensor filled with a scalar value.
+ *
+ * This function manages cached device memory for float32 tensors. If the current
+ * cache holds fewer elements than the requested shape, the old memory is released
+ * and new memory is allocated and initialized: with zeros via memset when
+ * @p value == 0.0f, otherwise with @p value via aclnn_fill_scalar. A cache that
+ * is already large enough is reused as-is and is not re-initialized. Finally, an
+ * aclTensor object is created from the cached memory and returned.
+ *
+ * @param ctx The CANN backend context that manages device memory.
+ * @param buffer A pointer to the cached device buffer (will be allocated
+ *               or reallocated if necessary).
+ * @param cache_element The current number of cached elements. This will be
+ *                      updated when the cache is expanded.
+ * @param ne The tensor shape array (number of elements in each dimension).
+ * @param nb The stride size for each dimension.
+ * @param dims The number of tensor dimensions.
+ * @param value The scalar value used to fill the tensor (supports zero
+ *              initialization via memset or arbitrary values via fill_scalar).
+ * @return An aclTensor pointer created from the cached buffer.
+ */
+static aclTensor* get_f32_cache_acl_tensor(
+    ggml_backend_cann_context& ctx,
+    void** buffer,
+    int64_t &cache_element,
+    int64_t* ne,
+    size_t* nb,
+    int64_t dims,
+    float value) {
+    // Calculate total number of elements
+    int64_t n_element = 1;
+    for (int i = 0; i < dims; i++) {
+        n_element *= ne[i];
+    }
+    size_t size = n_element * sizeof(float);
+
+    // Grow the cache only when it is too small for this request
+    if (cache_element < n_element) {
+        if (*buffer != nullptr) {
+            aclrtFree(*buffer);
+            *buffer = nullptr;
+        }
+
+        ACL_CHECK(aclrtMalloc(buffer, size, ACL_MEM_MALLOC_HUGE_FIRST));
+        cache_element = n_element;
+
+        // Initialize the fresh buffer: memset for zero, fill operator otherwise
+        if (value == 0.0f) {
+            ACL_CHECK(aclrtMemsetAsync(*buffer, size, 0, size, ctx.stream()));
+        } else {
+            int64_t pool_ne[1] = { n_element };
+            size_t pool_nb[1] = { sizeof(float) };
+            aclTensor* acl_value = ggml_cann_create_tensor(
+                *buffer, ACL_FLOAT, sizeof(float), pool_ne, pool_nb, 1);
+            aclnn_fill_scalar(ctx, value, acl_value);  // honor the requested value, not a hard-coded 1
+            ggml_cann_release_resources(ctx, acl_value);
+        }
+    }
+
+    return ggml_cann_create_tensor(*buffer, ACL_FLOAT, sizeof(float), ne, nb, dims);
+}
+
void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
ggml_tensor* src = dst->src[0];
@@ -875,20 +955,39 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
- size_t one_tensor_n_bytes = src->ne[0] * ggml_element_size(src);
- ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), one_tensor_n_bytes);
- aclTensor* acl_gamma = aclnn_values(
- ctx, one_tensor_allocator.get(), one_tensor_n_bytes, src->ne, 1,
- ggml_cann_type_mapping(src->type), ggml_element_size(src));
+ // build gamma, one...
+ size_t acl_gamma_nb[GGML_MAX_DIMS];
+ acl_gamma_nb[0] = sizeof(float);
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ acl_gamma_nb[i] = acl_gamma_nb[i - 1] * src->ne[i - 1];
+ }
+ aclTensor* acl_gamma = get_f32_cache_acl_tensor(
+ ctx,
+ &ctx.f32_one_cache,
+ ctx.f32_one_cache_element,
+ src->ne,
+ acl_gamma_nb,
+ 1, // dims
+ 1.0f // value
+ );
+
+ // build rstd, zero...
+ size_t acl_rstd_nb[GGML_MAX_DIMS];
+ acl_rstd_nb[0] = sizeof(float);
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ acl_rstd_nb[i] = acl_rstd_nb[i - 1] * src->ne[i - 1];
+ }
+ aclTensor* acl_rstd = get_f32_cache_acl_tensor(
+ ctx,
+ &ctx.f32_zero_cache,
+ ctx.f32_zero_cache_element,
+ src->ne,
+ acl_rstd_nb,
+ GGML_MAX_DIMS,
+ 0.0f // value
+ );
- size_t zero_tensor_n_bytes =
- src->ne[1] * src->ne[2] * src->ne[3] * ggml_element_size(src);
- ggml_cann_pool_alloc zero_tensor_allocator(ctx.pool(), zero_tensor_n_bytes);
- aclTensor* acl_rstd =
- aclnn_zero(ctx, zero_tensor_allocator.get(), zero_tensor_n_bytes,
- src->ne, GGML_MAX_DIMS, ggml_cann_type_mapping(src->type),
- ggml_element_size(src));
GGML_CANN_CALL_ACLNN_OP(ctx, RmsNorm, acl_src, acl_gamma, eps, acl_dst, acl_rstd);
ggml_cann_release_resources(ctx, acl_src, acl_dst, acl_gamma, acl_rstd);
}
@@ -903,14 +1002,13 @@ void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst,
const int n_past = ((int32_t*)dst->op_params)[0];
- size_t one_tensor_n_bytes = src->ne[0] * src->ne[1] * src->ne[2] *
- src->ne[3] * ggml_element_size(src);
- ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), one_tensor_n_bytes);
+ ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), ggml_nbytes(src));
+ void* buffer = one_tensor_allocator.get();
- aclTensor* mask_tensor =
- aclnn_values(ctx, one_tensor_allocator.get(), one_tensor_n_bytes,
- src->ne, GGML_MAX_DIMS, ggml_cann_type_mapping(src->type),
- ggml_element_size(src), value);
+ aclTensor* mask_tensor = ggml_cann_create_tensor(buffer, ggml_cann_type_mapping(src->type),
+ ggml_type_size(src->type), src->ne, src->nb, GGML_MAX_DIMS);
+
+ aclnn_fill_scalar(ctx, value, mask_tensor);
aclScalar* alpha = nullptr;
float alphaValue = 1.0f;
@@ -1159,12 +1257,20 @@ static void aclnn_exp(ggml_backend_cann_context& ctx, aclTensor* acl_src) {
void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
aclTensor* acl_dst) {
- GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst);
+ if(acl_dst == nullptr) {
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCos, acl_src);
+ } else {
+ GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst);
+ }
}
void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
aclTensor* acl_dst) {
- GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst);
+ if(acl_dst == nullptr) {
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSin, acl_src);
+ } else {
+ GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst);
+ }
}
void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
@@ -1277,23 +1383,6 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
tmp_permute_tensor, tmp_mul_tensor, acl_dst);
}
-/**
- * @brief Fills a tensor with a scalar value.
- *
- * This function fills the destination tensor `acl_dst` with the scalar value
- * `scalar`.
- *
- * @param ctx The context for the CANN backend operations.
- * @param scalar The scalar value used to fill the tensor.
- * @param acl_dst The destination tensor to be filled with the scalar value.
- */
-static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
- aclTensor* acl_dst) {
- auto acl_scalar = aclCreateScalar(&scalar, aclDataType::ACL_FLOAT);
- GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst, acl_scalar);
- ggml_cann_release_resources(ctx, acl_scalar);
-}
-
/**
* @brief Raises each element of a tensor to the power of the corresponding
* element in another tensor.
@@ -1338,17 +1427,17 @@ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx,
static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer,
float m, int64_t size, float start, float stop, float step){
int64_t ne[] = {size};
- size_t nb[] = {sizeof(float)};
+ size_t nb[] = {sizeof(uint16_t)};
- ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(float));
+ ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(uint16_t));
void* arange_buffer = arange_allocator.get();
aclTensor* arange_tensor = ggml_cann_create_tensor(
- arange_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);
+ arange_buffer, ACL_FLOAT16, sizeof(uint16_t), ne, nb, 1);
aclnn_arange(ctx, arange_tensor, start, stop, step, size);
aclTensor* slope_tensor = ggml_cann_create_tensor(
- slope_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);
+ slope_buffer, ACL_FLOAT16, sizeof(uint16_t), ne, nb, 1);
aclScalar* sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT);
@@ -2140,13 +2229,54 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
ggml_cann_release_resources(ctx, acl_index, acl_value);
}
+/**
+ * @brief Initializes and caches sine/cosine positional encoding values
+ * (used in RoPE, Rotary Position Embedding) for attention layers.
+ *
+ * This function computes and caches the sin/cos values of
+ * θ = position * theta_scale for RoPE encoding. The cache is shared
+ * across attention layers, and only the first attention layer will
+ * trigger initialization. The cache includes repeated sin/cos values
+ * with different repeat methods depending on the @param is_neox flag.
+ *
+ * Steps performed by this function:
+ * 1. Identify whether the target tensor belongs to Q/K in attention
+ * and restrict computation to the first layer only.
+ * 2. Initialize the theta scale array (arange → power → freq scaling).
+ * 3. Allocate sin/cos caches if the max prompt length increases.
+ * 4. Compute θ = position * theta_scale.
+ * 5. Compute sin(θ), cos(θ) and optionally scale by attn_factor.
+ * 6. Expand sin/cos values by repeat or repeat_interleave depending
+ * on whether @param is_neox is enabled.
+ * 7. Store the computed values into persistent buffers
+ * (ctx.rope_sin_ptr / ctx.rope_cos_ptr).
+ *
+ * @param ctx The CANN backend context, holding memory pool,
+ * stream, and persistent buffers for rope init/cache.
+ * @param dst The destination ggml_tensor whose computation
+ * depends on the cached RoPE values (usually Qcur/Kcur).
+ * @param theta_scale Scalar exponent base for computing theta scale values.
+ * @param freq_scale Frequency scaling factor, applied to theta scale.
+ * @param attn_factor Attention scaling factor, applied to sin/cos.
+ * @param is_neox Whether to use Neox-style repeat strategy
+ * (dim expansion vs repeat_interleave).
+ */
static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
- aclTensor* acl_cos_repeat_tensor,
- aclTensor* acl_sin_repeat_tensor,
float theta_scale, float freq_scale,
float attn_factor, bool is_neox) {
// int sin/cos cache, cache has different repeat method depond on
// @param.is_neox
+ bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0);
+ bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0);
+
+ // used for accuracy testing
+ bool is_attention = is_q || is_k;
+
+ // for attention tensors, only compute in the first layer
+ bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0);
+ if(is_attention && !is_fisrt_layer) {
+ return;
+ }
ggml_tensor* src0 = dst->src[0]; // input
ggml_tensor* src1 = dst->src[1]; // position
@@ -2172,21 +2302,16 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
}
- bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0);
- bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0);
-
- // used for accuracy testing
- bool is_attention = is_q || is_k;
-
- if(ctx.init_ptr == nullptr || !is_attention) {
+ // init theta scale only once (always re-initialized for non-attention tensors)
+ if(ctx.rope_init_ptr == nullptr || !is_attention) {
// theta_scale arange, [0,1,...,ne00/2 - 1]
- if(ctx.init_ptr != nullptr){
- ACL_CHECK(aclrtFree(ctx.init_ptr));
+ if(ctx.rope_init_ptr != nullptr){
+ ACL_CHECK(aclrtFree(ctx.rope_init_ptr));
}
- ACL_CHECK(aclrtMalloc(&ctx.init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
+ ACL_CHECK(aclrtMalloc(&ctx.rope_init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
aclTensor* acl_theta_scale_tensor =
- ggml_cann_create_tensor(ctx.init_ptr, ACL_FLOAT, sizeof(float_t),
+ ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
float start = 0;
float step = 1;
@@ -2216,67 +2341,55 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
ggml_cann_release_resources(ctx, acl_theta_scale_tensor,acl_theta_scale);
}
- if(ctx.sin_ptr == nullptr) {
- int64_t theta_length = theta_scale_length * ctx.max_prompt_length;
- ACL_CHECK(aclrtMalloc(&ctx.sin_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
- ACL_CHECK(aclrtMalloc(&ctx.cos_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
- }
+ // init sin_repeat && cos_repeat; for attention tensors this is done only in layer 0
if(position_length > ctx.max_prompt_length) {
ctx.max_prompt_length = position_length;
- int64_t theta_length = theta_scale_length * ctx.max_prompt_length;
- ACL_CHECK(aclrtFree(ctx.sin_ptr));
- ACL_CHECK(aclrtFree(ctx.cos_ptr));
- ACL_CHECK(aclrtMalloc(&ctx.sin_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
- ACL_CHECK(aclrtMalloc(&ctx.cos_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
+ int64_t repeat_theta_length = theta_scale_length * ctx.max_prompt_length * 2;
+ if(ctx.rope_sin_ptr != nullptr) {
+ ACL_CHECK(aclrtFree(ctx.rope_sin_ptr));
+ ACL_CHECK(aclrtFree(ctx.rope_cos_ptr));
+ }
+ ACL_CHECK(aclrtMalloc(&ctx.rope_sin_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
+ ACL_CHECK(aclrtMalloc(&ctx.rope_cos_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
}
- bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0);
-
- if(is_fisrt_layer || !is_attention) {
-
- aclTensor* acl_theta_scale_tensor =
- ggml_cann_create_tensor(ctx.init_ptr, ACL_FLOAT, sizeof(float_t),
+ aclTensor* acl_theta_scale_tensor =
+ ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
- // position
- aclTensor* acl_position_tensor = ggml_cann_create_tensor(
- src1->data, ggml_cann_type_mapping(src1->type),
- ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS);
+ // position
+ aclTensor* acl_position_tensor = ggml_cann_create_tensor(
+ src1->data, ggml_cann_type_mapping(src1->type),
+ ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS);
- // power * position
- int64_t theta_length = theta_scale_length * position_length;
- ggml_cann_pool_alloc theta_allocator(ctx.pool(),
- theta_length * sizeof(float_t));
- void* theta_buffer = theta_allocator.get();
+ // power * position
+ int64_t theta_length = theta_scale_length * position_length;
+ ggml_cann_pool_alloc theta_allocator(ctx.pool(),
+ theta_length * sizeof(float_t));
+ void* theta_buffer = theta_allocator.get();
- aclTensor* acl_theta_tensor =
- ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t),
- theta_ne, theta_nb, GGML_MAX_DIMS);
- aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
- acl_theta_tensor);
-
- // sin/cos
- aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
- ctx.sin_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
- GGML_MAX_DIMS, ACL_FORMAT_ND);
- aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor);
-
- aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
- ctx.cos_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
- GGML_MAX_DIMS, ACL_FORMAT_ND);
- aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);
-
- // release
- ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
- acl_theta_tensor, acl_sin_tensor, acl_cos_tensor);
- }
+ aclTensor* acl_theta_tensor =
+ ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t),
+ theta_ne, theta_nb, GGML_MAX_DIMS);
+ aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
+ acl_theta_tensor);
+ // sin/cos
+ ggml_cann_pool_alloc sin_allocator(ctx.pool(),
+ theta_length * sizeof(float_t));
+ void* sin_buffer = sin_allocator.get();
aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
- ctx.sin_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
- GGML_MAX_DIMS, ACL_FORMAT_ND);
+ sin_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
+ GGML_MAX_DIMS, ACL_FORMAT_ND);
+ aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor);
+
+ ggml_cann_pool_alloc cos_allocator(ctx.pool(),
+ theta_length * sizeof(float_t));
+ void* cos_buffer = cos_allocator.get();
aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
- ctx.cos_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
- GGML_MAX_DIMS, ACL_FORMAT_ND);
+ cos_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
+ GGML_MAX_DIMS, ACL_FORMAT_ND);
+ aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);
// attn_factor
if (attn_factor != 1) {
@@ -2284,6 +2397,19 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
aclnn_muls(ctx, acl_cos_tensor, attn_factor, nullptr, true);
}
+ int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
+ size_t sin_reshape_nb[GGML_MAX_DIMS];
+ sin_reshape_nb[0] = sizeof(float_t);
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
+ }
+ aclTensor* acl_sin_repeat_tensor =
+ ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
+ sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
+ aclTensor* acl_cos_repeat_tensor =
+ ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
+ sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
+
// repeat
if (is_neox) {
int64_t repeatsArray[] = {1, 1, 1, 2};
@@ -2299,8 +2425,9 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
num_repeats, output_size);
}
- // release
- ggml_cann_release_resources(ctx, acl_sin_tensor, acl_cos_tensor);
+ ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
+ acl_theta_tensor, acl_sin_tensor, acl_sin_repeat_tensor, acl_cos_tensor,
+ acl_cos_repeat_tensor);
}
#ifdef __cplusplus
@@ -2354,13 +2481,8 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
- // init cos/sin cache
- ggml_cann_pool_alloc sin_allocator(
- ctx.pool(), ne00 * ne02 * sizeof(float_t));
- ggml_cann_pool_alloc cos_allocator(
- ctx.pool(), ne00 * ne02 * sizeof(float_t));
- void* sin_buffer = sin_allocator.get();
- void* cos_buffer = cos_allocator.get();
+ // init ctx.rope_cos/rope_sin cache
+ aclnn_cache_init(ctx, dst, theta_scale, freq_scale, attn_factor, is_neox);
int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
size_t sin_reshape_nb[GGML_MAX_DIMS];
@@ -2369,13 +2491,11 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
}
aclTensor* acl_sin_reshape_tensor =
- ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float_t),
+ ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
aclTensor* acl_cos_reshape_tensor =
- ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float_t),
+ ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
- aclnn_cache_init(ctx, dst, acl_cos_reshape_tensor, acl_sin_reshape_tensor,
- theta_scale, freq_scale, attn_factor, is_neox);
aclTensor* acl_src = ggml_cann_create_tensor(src0);
aclTensor* acl_dst = ggml_cann_create_tensor(dst);
@@ -3060,11 +3180,38 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
- ggml_tensor* src0 = dst->src[0]; // q, fp32
- ggml_tensor* src1 = dst->src[1]; // k, fp16
- ggml_tensor* src2 = dst->src[2]; // v, fp16
+ ggml_tensor* src0 = dst->src[0]; // q, fp32 | B, N, S, D (uncont) -> B, S, N, D (cont)
+ ggml_tensor* src1 = dst->src[1]; // k, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont)
+ ggml_tensor* src2 = dst->src[2]; // v, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont)
ggml_tensor* src3 = dst->src[3]; // mask, fp16
+ // B, N, S, D (uncont) -> B, S, N, D (cont)
+ int64_t src0_bsnd_ne[GGML_MAX_DIMS];
+ memcpy(src0_bsnd_ne, src0->ne, GGML_MAX_DIMS * sizeof(int64_t));
+ size_t src0_bsnd_nb[GGML_MAX_DIMS];
+ memcpy(src0_bsnd_nb, src0->nb, GGML_MAX_DIMS * sizeof(size_t));
+ int64_t src1_bsnd_ne[GGML_MAX_DIMS];
+ memcpy(src1_bsnd_ne, src1->ne, GGML_MAX_DIMS * sizeof(int64_t));
+ size_t src1_bsnd_nb[GGML_MAX_DIMS];
+ memcpy(src1_bsnd_nb, src1->nb, GGML_MAX_DIMS * sizeof(size_t));
+ int64_t src2_bsnd_ne[GGML_MAX_DIMS];
+ memcpy(src2_bsnd_ne, src2->ne, GGML_MAX_DIMS * sizeof(int64_t));
+ size_t src2_bsnd_nb[GGML_MAX_DIMS];
+ memcpy(src2_bsnd_nb, src2->nb, GGML_MAX_DIMS * sizeof(size_t));
+
+ auto transpose12 = [](int64_t* ne, size_t* nb) {
+ int64_t ne_tmp = ne[1];
+ size_t nb_tmp = nb[1];
+ ne[1] = ne[2];
+ nb[1] = nb[2];
+ ne[2] = ne_tmp;
+ nb[2] = nb_tmp;
+ };
+
+ transpose12(src0_bsnd_ne, src0_bsnd_nb);
+ transpose12(src1_bsnd_ne, src1_bsnd_nb);
+ transpose12(src2_bsnd_ne, src2_bsnd_nb);
+
float maxBias = 0.0f;
float scaleValue = 1.0f;
float logitSoftcap = 0.0f;
@@ -3086,11 +3233,12 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
void* src0_f16_buffer = nullptr;
if(ggml_cann_type_mapping(src0->type) != faDataType){
- aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0);
+ aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne,
+ src0_bsnd_nb, GGML_MAX_DIMS);
src0_f16_buffer = src0_f16_allocator.alloc(
ggml_nelements(src0) * faElemSize);
- int64_t* src0_f16_ne = src0->ne;
+ int64_t* src0_f16_ne = src0_bsnd_ne;
size_t src0_f16_nb[GGML_MAX_DIMS];
src0_f16_nb[0] = sizeof(uint16_t);
for(int i = 1; i < GGML_MAX_DIMS; ++i){
@@ -3104,20 +3252,23 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType);
ggml_cann_release_resources(ctx, acl_src0_f32_tensor);
}else{
- acl_src0_f16_tensor = ggml_cann_create_tensor(src0);
+ acl_src0_f16_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne,
+ src0_bsnd_nb, GGML_MAX_DIMS);
}
// Step 2: create the acl tensors for src1 (Key), src2 (Value),
// and the direct output from FusedInferAttention
- acl_src1_f16_tensor = ggml_cann_create_tensor(src1);
- acl_src2_f16_tensor = ggml_cann_create_tensor(src2);
+ acl_src1_f16_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne,
+ src1_bsnd_nb, GGML_MAX_DIMS);
+ acl_src2_f16_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne,
+ src2_bsnd_nb, GGML_MAX_DIMS);
ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
void* out_f16_buffer = out_f16_allocator.alloc(
ggml_nelements(dst) * faElemSize);
- int64_t* out_f16_ne = src0->ne;
+ int64_t* out_f16_ne = src0_bsnd_ne;
size_t out_f16_nb[GGML_MAX_DIMS];
out_f16_nb[0] = faElemSize;
for(int i = 1; i < GGML_MAX_DIMS; ++i){
@@ -3131,88 +3282,81 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
// Step 3: create the PSEShift tensor if needed
// this tensor is considered as mask (f16) in the llama.cpp
-
aclTensor* bcast_pse_tensor = nullptr;
- int64_t bcast_pse_ne[GGML_MAX_DIMS];
- size_t bcast_pse_nb[GGML_MAX_DIMS];
ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool());
- void* bcast_pse_buffer = nullptr;
-
if(src3 != nullptr){
- bcast_pse_buffer = bcast_pse_allocator.alloc(
- ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t));
+ // Construct the truncated pse tensor (common for prefill/decode)
+ int64_t trunc_pse_ne[GGML_MAX_DIMS] = {
+ src3->ne[0], // D
+ src0->ne[1], // S (number of Q tokens)
+ src3->ne[2], // mask N
+ src3->ne[3] // B
+ };
+ size_t* trunc_pse_nb = src3->nb;
- if(src0->ne[1] > 1){
- // Case 1: broadcast pse for prefill stage with multiple head
- aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3);
- bcast_pse_ne[0] = src3->ne[0];
- bcast_pse_ne[1] = src3->ne[1];
- bcast_pse_ne[2] = src0->ne[2];
- bcast_pse_ne[3] = src3->ne[3];
+ aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
+ src3->data, ACL_FLOAT16, sizeof(uint16_t),
+ trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS
+ );
+ int64_t bcast_pse_ne[GGML_MAX_DIMS];
+ size_t bcast_pse_nb[GGML_MAX_DIMS];
+ bcast_pse_ne[0] = src3->ne[0]; // D
+ bcast_pse_ne[1] = src0->ne[1]; // S
+ bcast_pse_ne[2] = src0->ne[2]; // N (num_heads)
+ bcast_pse_ne[3] = src3->ne[3]; // B
+ if (maxBias == 0.0f) {
+ // When maxBias == 0.0f, use nb = 0 on the head dimension to avoid one repeat operation (Qwen2)
+ // Construct the bcast tensor (simulate repeat on the head dimension using stride=0)
bcast_pse_nb[0] = sizeof(uint16_t);
- for(int i = 1; i < GGML_MAX_DIMS; ++i){
- bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
- }
+ bcast_pse_nb[1] = bcast_pse_nb[0] * bcast_pse_ne[0];
+ bcast_pse_nb[2] = 0; // <---- the head dimension shares the same data
+ bcast_pse_nb[3] = src3->nb[3];
bcast_pse_tensor = ggml_cann_create_tensor(
- bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
- bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
-
- int64_t repeats[] = {1, src0->ne[2], 1, 1};
- aclnn_repeat(ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats);
-
- ggml_cann_release_resources(ctx, acl_mask_f16_tensor);
- }else{
- // Case 2: trunc the first row and broadcast pse for decode stage with multiple head
- int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]};
- size_t* trunc_pse_nb = src3->nb;
-
- aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
src3->data, ACL_FLOAT16, sizeof(uint16_t),
- trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS);
-
- bcast_pse_ne[0] = src3->ne[0];
- bcast_pse_ne[1] = src0->ne[1];
- bcast_pse_ne[2] = src0->ne[2];
- bcast_pse_ne[3] = src3->ne[3];
+ bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS
+ );
+ ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
+ } else {
bcast_pse_nb[0] = sizeof(uint16_t);
- for(int i = 1; i < GGML_MAX_DIMS; ++i){
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
}
+ void* bcast_pse_buffer = bcast_pse_allocator.alloc(
+ ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t)
+ );
+
bcast_pse_tensor = ggml_cann_create_tensor(
bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
- bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
+ bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS
+ );
int64_t repeats[] = {1, src0->ne[2], 1, 1};
aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats);
- ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
- }
-
- // Compute the slope if needed. Derived from ggml_cann_softmax().
- if(maxBias != 0.0f){
// alibi
+ // Compute the slope if needed. Derived from ggml_cann_softmax().
const int64_t n_heads = src0->ne[2];
- ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float));
+ ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(uint16_t));
void* slope_buffer = slope_allocator.get();
aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias);
int64_t slope_ne[] = {1, 1, n_heads, 1};
size_t slope_nb[GGML_MAX_DIMS];
- slope_nb[0] = sizeof(float);
+ slope_nb[0] = sizeof(uint16_t);
for(int i = 1;ine[0]); // 1/sqrt(d)
int64_t preTokens = 65535;
int64_t nextTokens = 65535;
- char layout[5] = {'B', 'N', 'S', 'D', 0};
+ char layout[5] = {'B', 'S', 'N', 'D', 0};
int64_t sparseMode = 0;
int64_t innerPrecise = (src0->ne[1] == 1) ? 0 : 2;
int64_t blockSize = 0;
@@ -3266,32 +3410,9 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
);
// Step 6: post-processing, permute and cast to f32
-
- int64_t new_dim[] = {0, 2, 1, 3};
aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
-
- if(ggml_cann_type_mapping(dst->type) != faDataType){
- ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool());
- perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
- void* perm_out_f16_buffer = perm_out_f16_allocator.get();
-
- int64_t* perm_out_f16_ne = dst->ne;
- size_t perm_out_f16_nb[GGML_MAX_DIMS];
- perm_out_f16_nb[0] = faElemSize;
- for(int i = 1; i < GGML_MAX_DIMS; ++i){
- perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1];
- }
- aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor(
- perm_out_f16_buffer, faDataType, faElemSize,
- perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS);
- aclnn_permute(ctx, acl_dst_f16_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS);
- aclnn_cast(ctx,
- acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
- ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor);
- }else{
- // only need to permute
- aclnn_permute(ctx, acl_dst_f16_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS);
- }
+ // TODO: when dst is fp16, the cast is not needed
+ aclnn_cast(ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
acl_src1_f16_tensor,
acl_src2_f16_tensor,
diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h
index 2c2033bfba..88cc3f481e 100755
--- a/ggml/src/ggml-cann/common.h
+++ b/ggml/src/ggml-cann/common.h
@@ -368,17 +368,22 @@ struct ggml_backend_cann_context {
std::string name; /**< Name of the device. */
std::string description; /**< Description of the device. */
aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
- void* init_ptr = nullptr;
- void* sin_ptr = nullptr;
- void* cos_ptr = nullptr;
- int64_t max_prompt_length = 65536;
#ifdef USE_ACL_GRAPH
/// Cached CANN ACL graph used for executing the current ggml computation graph.
std::unique_ptr<ggml_cann_graph> cann_graph;
#endif
cann_task_queue task_queue;
bool async_mode;
- bool support_set_rows;
+ // Rope Cache
+ void* rope_init_ptr = nullptr;
+ void* rope_sin_ptr = nullptr;
+ void* rope_cos_ptr = nullptr;
+ int64_t max_prompt_length = 0;
+ // Constant Pool
+ void* f32_zero_cache = nullptr;
+ void* f32_one_cache = nullptr;
+ int64_t f32_zero_cache_element = 0;
+ int64_t f32_one_cache_element = 0;
aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
@@ -394,14 +399,6 @@ struct ggml_backend_cann_context {
async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__,
device, async_mode ? "ON" : "OFF");
-
- support_set_rows = parse_bool(get_env("LLAMA_SET_ROWS").value_or(""));
- GGML_LOG_INFO("%s: LLAMA_SET_ROWS is %s\n", __func__, support_set_rows ? "ON" : "OFF");
-
- if (!support_set_rows) {
- GGML_LOG_INFO("%s: CANN Graph currently only supports execution when LLAMA_SET_ROWS is ON. "
- "Falling back to eager mode.\n", __func__);
- }
}
/**
@@ -418,14 +415,20 @@ struct ggml_backend_cann_context {
ACL_CHECK(aclrtDestroyStream(streams[i]));
}
}
- if(init_ptr != nullptr) {
- ACL_CHECK(aclrtFree(init_ptr));
+ if(rope_init_ptr != nullptr) {
+ ACL_CHECK(aclrtFree(rope_init_ptr));
}
- if(sin_ptr != nullptr) {
- ACL_CHECK(aclrtFree(sin_ptr));
+ if(rope_sin_ptr != nullptr) {
+ ACL_CHECK(aclrtFree(rope_sin_ptr));
}
- if(cos_ptr != nullptr) {
- ACL_CHECK(aclrtFree(cos_ptr));
+ if(rope_cos_ptr != nullptr) {
+ ACL_CHECK(aclrtFree(rope_cos_ptr));
+ }
+ if(f32_zero_cache != nullptr) {
+ ACL_CHECK(aclrtFree(f32_zero_cache));
+ }
+ if(f32_one_cache != nullptr) {
+ ACL_CHECK(aclrtFree(f32_one_cache));
}
}
diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
index cb8af42ebf..7b3aca9db9 100755
--- a/ggml/src/ggml-cann/ggml-cann.cpp
+++ b/ggml/src/ggml-cann/ggml-cann.cpp
@@ -1155,7 +1155,7 @@ namespace {
* @note The workspace buffer used in this function is managed globally and reused
* across calls. This reduces overhead from repeated memory allocation and deallocation.
*/
-static void weight_format_to_nz(ggml_tensor *tensor, const void *data, size_t offset) {
+static void weight_format_to_nz(ggml_tensor *tensor, size_t offset) {
aclTensor* weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne,
tensor->nb, 2, ACL_FORMAT_ND, offset);
uint64_t workspaceSize = 0;
@@ -1203,7 +1203,7 @@ static void ggml_backend_cann_buffer_set_tensor(
if (weight_to_nz && is_matmul_weight((const ggml_tensor*)tensor)) {
GGML_ASSERT(tensor->ne[2] == 1);
GGML_ASSERT(tensor->ne[3] == 1);
- weight_format_to_nz(tensor, data, offset);
+ weight_format_to_nz(tensor, offset);
}
} else {
void *transform_buffer = malloc(size);
@@ -2251,11 +2251,6 @@ static enum ggml_status ggml_backend_cann_graph_compute(
bool use_cann_graph = true;
bool cann_graph_update_required = false;
- // check environment LLAMA_SET_ROWS
- if (!cann_ctx->support_set_rows) {
- use_cann_graph = false;
- }
-
if (use_cann_graph) {
if (cann_ctx->cann_graph == nullptr) {
cann_ctx->cann_graph.reset(new ggml_cann_graph());
@@ -2336,7 +2331,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_0:
#ifdef ASCEND_310P
- // Q4 && Q8 per group is not suppor on 310p device
+ // Q4 && Q8 per group is not supported on 310p device
return false;
#endif
// only support contiguous for quantized types.
@@ -2354,7 +2349,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_0:
#ifdef ASCEND_310P
- // Q4 && Q8 per group is not suppor on 310p device
+ // Q4 && Q8 per group is not supported on 310p device
return false;
#endif
// only support contiguous for quantized types.
@@ -2496,7 +2491,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
return true;
case GGML_OP_SCALE:
float bias;
- memcpy(&bias, (float*)op->op_params + 1, sizeof(float));
+ memcpy(&bias, (const float *)(op->op_params) + 1, sizeof(float));
return bias == 0.0f; // TODO: support bias != 0.0f
case GGML_OP_SOFT_MAX:
// TODO: support attention sinks [TAG_ATTN_SINKS]
@@ -2505,6 +2500,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
}
return true;
case GGML_OP_FLASH_ATTN_EXT:{
+#ifdef ASCEND_310P
+ // FA is not supported on 310p device
+ return false;
+#endif
// derived from [ggml-cuda.cu]
if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
return false;
@@ -2530,8 +2529,12 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
// DeepSeek MLA
return false;
}
+ if (op->src[0]->ne[0] % 16 != 0) {
+ // TODO: pad src[0]->ne[0] to a multiple of 16 to support this case
+ return false;
+ }
float logitSoftcap = 0.0f;
- memcpy(&logitSoftcap, (float*)op->op_params + 2, sizeof(float));
+ memcpy(&logitSoftcap, (const float *)(op->op_params) + 2, sizeof(float));
if(logitSoftcap != 0.0f) {
return false;
}
diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt
index ce0a3e1285..b70302ec8c 100644
--- a/ggml/src/ggml-cpu/CMakeLists.txt
+++ b/ggml/src/ggml-cpu/CMakeLists.txt
@@ -435,7 +435,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
)
if (GGML_RVV)
if (GGML_XTHEADVECTOR)
- list(APPEND ARCH_FLAGS -march=rv64gc_xtheadvector -mabi=lp64d)
+ list(APPEND ARCH_FLAGS -march=rv64gc_zfhmin_xtheadvector -mabi=lp64d)
elseif (GGML_RV_ZFH)
list(APPEND ARCH_FLAGS -march=rv64gcv_zfhmin -mabi=lp64d)
else()
diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h
index 0bfb92df17..373408a9c0 100644
--- a/ggml/src/ggml-cpu/arch-fallback.h
+++ b/ggml/src/ggml-cpu/arch-fallback.h
@@ -150,8 +150,6 @@
#elif defined(__s390x__)
// quants.c
#define quantize_row_q8_K_generic quantize_row_q8_K
-#define ggml_vec_dot_q5_0_q8_0_generic ggml_vec_dot_q5_0_q8_0
-#define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
diff --git a/ggml/src/ggml-cpu/arch/s390/quants.c b/ggml/src/ggml-cpu/arch/s390/quants.c
index 7e4229d0e4..1c8176fb4d 100644
--- a/ggml/src/ggml-cpu/arch/s390/quants.c
+++ b/ggml/src/ggml-cpu/arch/s390/quants.c
@@ -23,6 +23,27 @@
#define UNUSED GGML_UNUSED
+#if defined(__VXE__) || defined(__VXE2__)
+#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
+#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
+#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
+#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
+#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
+#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
+#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
+#define B8(c,s ) B7(c,s, c), B7(c,s, s)
+
+// precomputed tables for expanding 8bits to 8 bytes:
+static const __attribute__((aligned(16))) uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b ) << 4
+static const __attribute__((aligned(16))) uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
+
+// permute mask for byteswapping
+static const uint8x16_t v_kperm = (const uint8x16_t){
+ 7, 6, 5, 4, 3, 2, 1, 0,
+ 15, 14, 13, 12, 11, 10, 9, 8
+};
+#endif
+
void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
assert(QK8_0 == 32);
assert(k % QK8_0 == 0);
@@ -241,6 +262,301 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
#endif
}
+void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { // stores in *s the dot product of n Q5_0 values (vx) and n Q8_0 values (vy); bs, bx, by and nrc are unused and nrc must be 1
+ const int qk = QK8_0;
+ const int nb = n / qk; // number of quantization blocks in each operand
+
+ assert(n % qk == 0);
+ assert(qk == QK5_0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_q5_0 * GGML_RESTRICT x = vx;
+ const block_q8_0 * GGML_RESTRICT y = vy;
+
+ int ib = 0;
+ float sumf = 0.0f;
+
+#if defined(__VXE__) || defined(__VXE2__)
+ float32x4_t v_sum0 = vec_splats(0.0f);
+ float32x4_t v_sum1 = vec_splats(0.0f);
+
+ uint32_t qh0, qh1;
+ uint64_t tmp0[4], tmp1[4];
+
+ const uint8x16_t v_m = vec_splats((uint8_t)0x0F); // mask selecting the low nibble of every byte
+
+ #pragma GCC unroll 4
+ for (; ib + 1 < nb; ib += 2) { // main loop: two blocks per iteration, one vector accumulator each
+ const block_q5_0 * GGML_RESTRICT x0 = &x[ib + 0];
+ const block_q5_0 * GGML_RESTRICT x1 = &x[ib + 1];
+ const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
+ const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];
+
+ memcpy(&qh0, x0->qh, sizeof(qh0)); // the block's 32 high bits, one per quantized value
+ memcpy(&qh1, x1->qh, sizeof(qh1));
+
+ tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF]; // each qh byte expands to 8 bytes: 0x10 where the bit is clear, 0x00 where it is set
+ tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF];
+ tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF];
+ tmp0[3] = table_b2b_1[(qh0 >> 24) ];
+
+ tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF];
+ tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF];
+ tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF];
+ tmp1[3] = table_b2b_1[(qh1 >> 24) ];
+
+ int8x16_t v_qh0l = vec_xl(0, (const int8_t *)(tmp0 + 0));
+ int8x16_t v_qh0h = vec_xl(0, (const int8_t *)(tmp0 + 2));
+ int8x16_t v_qh1l = vec_xl(0, (const int8_t *)(tmp1 + 0));
+ int8x16_t v_qh1h = vec_xl(0, (const int8_t *)(tmp1 + 2));
+
+ // required for fixing the byteorder
+ v_qh0l = vec_perm(v_qh0l, v_qh0l, v_kperm);
+ v_qh0h = vec_perm(v_qh0h, v_qh0h, v_kperm);
+ v_qh1l = vec_perm(v_qh1l, v_qh1l, v_kperm);
+ v_qh1h = vec_perm(v_qh1h, v_qh1h, v_kperm);
+
+ const uint8x16_t v_x0 = vec_xl(0, (const uint8_t *)x0->qs);
+ const uint8x16_t v_x1 = vec_xl(0, (const uint8_t *)x1->qs);
+
+ int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m); // low nibbles
+ int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4); // high nibbles
+ int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m);
+ int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4);
+
+ const int8x16_t v_x0lf = vec_sub(v_x0l, v_qh0l); // nibble - ((!b) << 4) == (nibble | (b << 4)) - 16
+ const int8x16_t v_x0hf = vec_sub(v_x0h, v_qh0h);
+ const int8x16_t v_x1lf = vec_sub(v_x1l, v_qh1l);
+ const int8x16_t v_x1hf = vec_sub(v_x1h, v_qh1h);
+
+ const int8x16_t v_y0l = vec_xl(0, (const int8_t *)y0->qs);
+ const int8x16_t v_y0h = vec_xl(QK8_0/2, (const int8_t *)y0->qs); // second half of y0->qs
+ const int8x16_t v_y1l = vec_xl(0, (const int8_t *)y1->qs);
+ const int8x16_t v_y1h = vec_xl(QK8_0/2, (const int8_t *)y1->qs);
+
+ const int32x4_t v_xy0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0lf, v_y0l), v_x0hf, v_y0h); // integer products of the whole block, spread over four int32 lanes
+ const int32x4_t v_xy1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1lf, v_y1l), v_x1hf, v_y1h);
+
+ const float32x4_t v_xy0f = vec_float(v_xy0);
+ const float32x4_t v_xy1f = vec_float(v_xy1);
+
+ const float32x4_t v_d0 = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d)); // combined scale of the two blocks
+ const float32x4_t v_d1 = vec_splats(GGML_CPU_FP16_TO_FP32(x1->d) * GGML_CPU_FP16_TO_FP32(y1->d));
+
+ v_sum0 = vec_madd(v_xy0f, v_d0, v_sum0); // v_sum0 += v_xy0f * v_d0
+ v_sum1 = vec_madd(v_xy1f, v_d1, v_sum1);
+ }
+
+ sumf += vec_hsum(v_sum0) + vec_hsum(v_sum1);
+
+ #pragma GCC unroll 4
+ for (; ib < nb; ++ib) { // tail: at most one block remains after the paired loop
+ const block_q5_0 * GGML_RESTRICT x0 = &x[ib];
+ const block_q8_0 * GGML_RESTRICT y0 = &y[ib];
+
+ uint32_t qh;
+ memcpy(&qh, x0->qh, sizeof(qh));
+
+ uint64_t tmp[4];
+ tmp[0] = table_b2b_1[(qh >> 0) & 0xFF];
+ tmp[1] = table_b2b_1[(qh >> 8) & 0xFF];
+ tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];
+ tmp[3] = table_b2b_1[(qh >> 24) ];
+
+ int8x16_t v_qhl = vec_xl(0, (const int8_t *)(tmp + 0));
+ int8x16_t v_qhh = vec_xl(0, (const int8_t *)(tmp + 2));
+
+ // required for fixing the byteorder
+ v_qhl = vec_perm(v_qhl, v_qhl, v_kperm);
+ v_qhh = vec_perm(v_qhh, v_qhh, v_kperm);
+
+ const uint8x16_t v_x = vec_xl(0, (const uint8_t *)x0->qs);
+ int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m);
+ int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4);
+
+ const int8x16_t v_xlf = vec_sub(v_xl, v_qhl);
+ const int8x16_t v_xhf = vec_sub(v_xh, v_qhh);
+
+ const int8x16_t v_yl = vec_xl(0, (const int8_t *)y0->qs);
+ const int8x16_t v_yh = vec_xl(QK8_0/2, (const int8_t *)y0->qs);
+
+ const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xlf, v_yl), v_xhf, v_yh);
+ const float32x4_t v_xyf = vec_float(v_xy);
+
+ const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d));
+ const float32x4_t v_acc = vec_madd(v_xyf, v_d, vec_splats(0.0f));
+
+ sumf += vec_hsum(v_acc);
+ }
+
+ *s = sumf;
+#else
+ UNUSED(nb);
+ UNUSED(x);
+ UNUSED(y);
+ UNUSED(ib);
+ UNUSED(sumf);
+ ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); // no vector extension available: defer to the generic implementation
+#endif
+}
+
+void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { // stores in *s the dot product of n Q5_1 values (vx) and n Q8_1 values (vy); bs, bx, by and nrc are unused and nrc must be 1
+ const int qk = QK8_1;
+ const int nb = n / qk; // number of quantization blocks in each operand
+
+ assert(n % qk == 0);
+ assert(qk == QK5_1);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_q5_1 * GGML_RESTRICT x = vx;
+ const block_q8_1 * GGML_RESTRICT y = vy;
+
+ int ib = 0;
+ float sumf = 0.0f;
+
+#if defined(__VXE__) || defined(__VXE2__)
+ float32x4_t v_sum0 = vec_splats(0.0f);
+ float32x4_t v_sum1 = vec_splats(0.0f);
+
+ float summs0 = 0.0f;
+ float summs1 = 0.0f;
+
+ uint32_t qh0;
+ uint32_t qh1;
+
+ uint64_t tmp0[4];
+ uint64_t tmp1[4];
+
+ const uint8x16_t v_m = vec_splats((uint8_t)0x0F); // mask selecting the low nibble of every byte
+
+ #pragma GCC unroll 4
+ for (; ib + 1 < nb; ib += 2) { // main loop: two blocks per iteration, one vector accumulator each
+ const block_q5_1 * GGML_RESTRICT x0 = &x[ib + 0];
+ const block_q5_1 * GGML_RESTRICT x1 = &x[ib + 1];
+ const block_q8_1 * GGML_RESTRICT y0 = &y[ib + 0];
+ const block_q8_1 * GGML_RESTRICT y1 = &y[ib + 1];
+
+ summs0 += GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s); // per-block offset term x->m * y->s, accumulated in scalar form
+ summs1 += GGML_CPU_FP16_TO_FP32(x1->m) * GGML_CPU_FP16_TO_FP32(y1->s);
+
+ memcpy(&qh0, x0->qh, sizeof(qh0)); // the block's 32 high bits, one per quantized value
+ memcpy(&qh1, x1->qh, sizeof(qh1));
+
+ tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF]; // each qh byte expands to 8 bytes: 0x10 where the bit is set, 0x00 where it is clear
+ tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF];
+ tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF];
+ tmp0[3] = table_b2b_0[(qh0 >> 24) ];
+
+ tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF];
+ tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF];
+ tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF];
+ tmp1[3] = table_b2b_0[(qh1 >> 24) ];
+
+ int8x16_t v_qh0l = vec_xl(0, (const int8_t *)(tmp0 + 0));
+ int8x16_t v_qh0h = vec_xl(0, (const int8_t *)(tmp0 + 2));
+ int8x16_t v_qh1l = vec_xl(0, (const int8_t *)(tmp1 + 0));
+ int8x16_t v_qh1h = vec_xl(0, (const int8_t *)(tmp1 + 2));
+
+ // required for fixing the byteorder
+ v_qh0l = vec_perm(v_qh0l, v_qh0l, v_kperm);
+ v_qh0h = vec_perm(v_qh0h, v_qh0h, v_kperm);
+ v_qh1l = vec_perm(v_qh1l, v_qh1l, v_kperm);
+ v_qh1h = vec_perm(v_qh1h, v_qh1h, v_kperm);
+
+ const uint8x16_t v_x0 = vec_xl(0, x0->qs);
+ const uint8x16_t v_x1 = vec_xl(0, x1->qs);
+
+ const int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m); // low nibbles
+ const int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4); // high nibbles
+ const int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m);
+ const int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4);
+
+ const int8x16_t v_x0lf = vec_or(v_x0l, v_qh0l); // nibble | (b << 4): the 5-bit value
+ const int8x16_t v_x0hf = vec_or(v_x0h, v_qh0h);
+ const int8x16_t v_x1lf = vec_or(v_x1l, v_qh1l);
+ const int8x16_t v_x1hf = vec_or(v_x1h, v_qh1h);
+
+ const int8x16_t v_y0l = vec_xl(0 , y0->qs);
+ const int8x16_t v_y0h = vec_xl(QK8_1/2, y0->qs); // second half of y0->qs
+ const int8x16_t v_y1l = vec_xl(0 , y1->qs);
+ const int8x16_t v_y1h = vec_xl(QK8_1/2, y1->qs);
+
+ const int32x4_t v_xy0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0lf, v_y0l), v_x0hf, v_y0h); // integer products of the whole block, spread over four int32 lanes
+ const int32x4_t v_xy1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1lf, v_y1l), v_x1hf, v_y1h);
+
+ const float32x4_t v_xy0f = vec_float(v_xy0);
+ const float32x4_t v_xy1f = vec_float(v_xy1);
+
+ const float32x4_t v_d0 = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d)); // combined scale of the two blocks
+ const float32x4_t v_d1 = vec_splats(GGML_CPU_FP16_TO_FP32(x1->d) * GGML_CPU_FP16_TO_FP32(y1->d));
+
+ v_sum0 = vec_madd(v_xy0f, v_d0, v_sum0); // v_sum0 += v_xy0f * v_d0
+ v_sum1 = vec_madd(v_xy1f, v_d1, v_sum1);
+ }
+
+ sumf += vec_hsum(v_sum0) + vec_hsum(v_sum1) + summs0 + summs1;
+
+ #pragma GCC unroll 4
+ for (; ib < nb; ++ib) { // tail: at most one block remains after the paired loop
+ const block_q5_1 * GGML_RESTRICT x0 = &x[ib];
+ const block_q8_1 * GGML_RESTRICT y0 = &y[ib];
+
+ float summs = GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s);
+
+ uint32_t qh;
+ memcpy(&qh, x0->qh, sizeof(qh));
+
+ uint64_t tmp[4];
+ tmp[0] = table_b2b_0[(qh >> 0) & 0xFF];
+ tmp[1] = table_b2b_0[(qh >> 8) & 0xFF];
+ tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];
+ tmp[3] = table_b2b_0[(qh >> 24) ];
+
+ int8x16_t v_qhl = vec_xl(0, (const int8_t *)(tmp + 0));
+ int8x16_t v_qhh = vec_xl(0, (const int8_t *)(tmp + 2));
+
+ // required for fixing the byteorder
+ v_qhl = vec_perm(v_qhl, v_qhl, v_kperm);
+ v_qhh = vec_perm(v_qhh, v_qhh, v_kperm);
+
+ const uint8x16_t v_x = vec_xl(0, x0->qs);
+ const int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m);
+ const int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4);
+
+ const int8x16_t v_xlf = vec_or(v_xl, v_qhl);
+ const int8x16_t v_xhf = vec_or(v_xh, v_qhh);
+
+ const int8x16_t v_yl = vec_xl(0 , y0->qs);
+ const int8x16_t v_yh = vec_xl(QK8_1/2, y0->qs);
+
+ const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xlf, v_yl), v_xhf, v_yh);
+ const float32x4_t v_xyf = vec_float(v_xy);
+
+ const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d));
+ const float32x4_t v_acc = vec_madd(v_xyf, v_d, vec_splats(0.0f)); // start from zero: v_acc must not be read inside its own initializer
+
+ sumf += vec_hsum(v_acc) + summs;
+ }
+
+ *s = sumf;
+#else
+ UNUSED(nb);
+ UNUSED(x);
+ UNUSED(y);
+ UNUSED(ib);
+ UNUSED(sumf);
+ ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc); // no vector extension available: defer to the generic implementation
+#endif
+}
+
void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h
index d839cf5c55..e08c30a348 100644
--- a/ggml/src/ggml-cpu/ggml-cpu-impl.h
+++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h
@@ -486,6 +486,14 @@ inline static int16x8_t vec_padd_s16(int16x8_t a, int16x8_t b) {
return v_abo + v_abe;
}
+/**
+ * @see https://github.com/ggml-org/llama.cpp/pull/14037
+ */
+inline static float vec_hsum(float32x4_t v) { // horizontal sum: adds v to its element-reversed copy, then adds lanes 0 and 1 of the result
+    const float32x4_t v_pairs = v + vec_reve(v);
+    return v_pairs[0] + v_pairs[1];
+}
+
inline static int32x4_t ggml_vec_dot(int32x4_t acc, int8x16_t a, int8x16_t b) {
const int16x8_t p = vec_mule(a, b) + vec_mulo(a, b);
return acc + (vec_unpackh(p) + vec_unpackl(p));
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index f6bea3df34..0d5d3a3440 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -1880,6 +1880,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_conv_2d(params, tensor);
} break;
+ case GGML_OP_CONV_3D:
+ {
+ ggml_compute_forward_conv_3d(params, tensor);
+ } break;
case GGML_OP_CONV_2D_DW:
{
ggml_compute_forward_conv_2d_dw(params, tensor);
@@ -2252,6 +2256,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_IM2COL:
case GGML_OP_IM2COL_BACK:
case GGML_OP_CONV_2D:
+ case GGML_OP_CONV_3D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_CONV_TRANSPOSE_1D:
case GGML_OP_CONV_TRANSPOSE_2D:
@@ -2773,6 +2778,7 @@ struct ggml_cplan ggml_graph_plan(
}
} break;
case GGML_OP_CONV_2D:
+ case GGML_OP_CONV_3D:
{
cur = GGML_IM2COL_WORK_SIZE;
} break;
diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
index 2be54c31b5..2c4ad9d58b 100644
--- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp
+++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
@@ -2169,94 +2169,117 @@ class tinyBLAS_Q0_PPC {
class tinyBLAS_PPC {
public:
tinyBLAS_PPC(int64_t k,
- const float *A, int64_t lda,
- const float *B, int64_t ldb,
- float *C, int64_t ldc,
+ const float * A, int64_t lda, // input A; the kernels address it as A + ii*lda + l
+ const float * B, int64_t ldb, // input B; the kernels address it as B + jj*ldb + l
+ float * C, int64_t ldc, // output; save_acc writes element (i, j) at C[i + j*ldc]
int ith, int nth)
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
}
void matmul(int64_t m, int64_t n) {
- mnpack(0, m, 0, n);
+ int64_t mc = 256; int64_t nc = 256; int64_t kc = 256; // fixed tile sizes for the tiled path
+ if (m % mc == 0 && n % nc == 0 && k % kc == 0) { // tiled path only when m, n and k are all whole multiples of the tile size
+ matmul_tiled(m, n, mc, nc, kc);
+ } else {
+ mnpack(0, m, 0, n); // otherwise fall back to mnpack over the full m x n range
+ }
}
private:
- void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
-
- inline void vector_permute_store_4(vector float *src, float *vecOffset) {
- vector float t1, t2, t3, t4, t5, t6, t7, t8;
- t1 = vec_mergeh(src[0], src[1]);
- t2 = vec_mergeh(src[2], src[3]);
- t3 = vec_mergel(src[0], src[1]);
- t4 = vec_mergel(src[2], src[3]);
-
- t5 = vec_xxpermdi(t1, t2, 0);
- t6 = vec_xxpermdi(t1, t2, 3);
- t7 = vec_xxpermdi(t3, t4, 0);
- t8 = vec_xxpermdi(t3, t4, 3);
-
- vec_xst(t5, 0, vecOffset);
- vec_xst(t6, 0, vecOffset + 4);
- vec_xst(t7, 0, vecOffset + 8);
- vec_xst(t8, 0, vecOffset + 12);
- }
-
- inline void vector_permute_store_8(vector float *src, float *vecOffset) {
- vector float t1, t2, t3, t4, t5, t6, t7, t8;
- t1 = vec_mergeh(src[0], src[1]);
- t2 = vec_mergeh(src[2], src[3]);
- t3 = vec_mergeh(src[4], src[5]);
- t4 = vec_mergeh(src[6], src[7]);
-
- t5 = vec_xxpermdi(t1, t2, 0);
- t6 = vec_xxpermdi(t3, t4, 0);
- t7 = vec_xxpermdi(t1, t2, 3);
- t8 = vec_xxpermdi(t3, t4, 3);
-
- vec_xst(t5, 0, vecOffset);
- vec_xst(t6, 0, vecOffset + 4);
- vec_xst(t7, 0, vecOffset + 8);
- vec_xst(t8, 0, vecOffset + 12);
-
- t1 = vec_mergel(src[0], src[1]);
- t2 = vec_mergel(src[2], src[3]);
- t3 = vec_mergel(src[4], src[5]);
- t4 = vec_mergel(src[6], src[7]);
-
- t5 = vec_xxpermdi(t1, t2, 0);
- t6 = vec_xxpermdi(t3, t4, 0);
- t7 = vec_xxpermdi(t1, t2, 3);
- t8 = vec_xxpermdi(t3, t4, 3);
-
- vec_xst(t5, 0, vecOffset + 16);
- vec_xst(t6, 0, vecOffset + 20);
- vec_xst(t7, 0, vecOffset + 24);
- vec_xst(t8, 0, vecOffset + 28);
+    inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) { // overwrite the 4x4 patch of C at (ii, jj) with the contents of accumulator ACC
+        vec_t tile[4];
+        __builtin_mma_disassemble_acc(tile, ACC);
+        for (int di = 0; di < 4; di++) {
+            for (int dj = 0; dj < 4; dj++) {
+                C[ii + ((jj + dj) * ldc) + di] = ((float *)&tile[di])[dj];
+            }
+        }
+    }
- void packTranspose(const float* a, int64_t lda, int rows, int cols, float* vec) {
+    inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) { // add the contents of accumulator ACC onto the 4x4 patch of C at (ii, jj)
+        vec_t tile[4];
+        __builtin_mma_disassemble_acc(tile, ACC);
+        for (int di = 0; di < 4; di++) {
+            for (int dj = 0; dj < 4; dj++) {
+                float & out = C[ii + ((jj + dj) * ldc) + di];
+                out += ((float *)&tile[di])[dj];
+            }
+        }
+    }
+
+    inline void vector_permute_store_4(vector float * src, float * vecOffset) {
+        // merge src[0..3] pairwise (high and low halves), recombine doublewords, and store 16 floats at vecOffset
+        const vector float hi01 = vec_mergeh(src[0], src[1]);
+        const vector float hi23 = vec_mergeh(src[2], src[3]);
+        const vector float lo01 = vec_mergel(src[0], src[1]);
+        const vector float lo23 = vec_mergel(src[2], src[3]);
+
+        const vector float out0 = vec_xxpermdi(hi01, hi23, 0);
+        const vector float out1 = vec_xxpermdi(hi01, hi23, 3);
+        const vector float out2 = vec_xxpermdi(lo01, lo23, 0);
+        const vector float out3 = vec_xxpermdi(lo01, lo23, 3);
+
+        vec_xst(out0, 0, vecOffset);
+        vec_xst(out1, 0, vecOffset + 4);
+        vec_xst(out2, 0, vecOffset + 8);
+        vec_xst(out3, 0, vecOffset + 12);
+    }
+
+    inline void vector_permute_store_8(vector float * src, float * vecOffset) {
+        // merge src[0..7] pairwise, recombine doublewords, and store 32 floats at vecOffset: high-half merges first, then low-half merges
+        const vector float hi01 = vec_mergeh(src[0], src[1]);
+        const vector float hi23 = vec_mergeh(src[2], src[3]);
+        const vector float hi45 = vec_mergeh(src[4], src[5]);
+        const vector float hi67 = vec_mergeh(src[6], src[7]);
+
+        const vector float hi_out0 = vec_xxpermdi(hi01, hi23, 0);
+        const vector float hi_out1 = vec_xxpermdi(hi45, hi67, 0);
+        const vector float hi_out2 = vec_xxpermdi(hi01, hi23, 3);
+        const vector float hi_out3 = vec_xxpermdi(hi45, hi67, 3);
+
+        vec_xst(hi_out0, 0, vecOffset);
+        vec_xst(hi_out1, 0, vecOffset + 4);
+        vec_xst(hi_out2, 0, vecOffset + 8);
+        vec_xst(hi_out3, 0, vecOffset + 12);
+
+        const vector float lo01 = vec_mergel(src[0], src[1]);
+        const vector float lo23 = vec_mergel(src[2], src[3]);
+        const vector float lo45 = vec_mergel(src[4], src[5]);
+        const vector float lo67 = vec_mergel(src[6], src[7]);
+
+        const vector float lo_out0 = vec_xxpermdi(lo01, lo23, 0);
+        const vector float lo_out1 = vec_xxpermdi(lo45, lo67, 0);
+        const vector float lo_out2 = vec_xxpermdi(lo01, lo23, 3);
+        const vector float lo_out3 = vec_xxpermdi(lo45, lo67, 3);
+
+        vec_xst(lo_out0, 0, vecOffset + 16);
+        vec_xst(lo_out1, 0, vecOffset + 20);
+        vec_xst(lo_out2, 0, vecOffset + 24);
+        vec_xst(lo_out3, 0, vecOffset + 28);
+    }
+
+ void packTranspose(const float * a, int64_t lda, int rows, int cols, float * vec) {
int64_t i, j;
float * aoffsets[8];
- float *aoffset = NULL, *boffset = NULL;
+ float * aoffset = NULL, * boffset = NULL;
__vector_pair arr[8];
vector float c[8][2] = {0};
vector float c1[8] = {0};
vector float c2[8] = {0};
- aoffset = const_cast(a);
+ aoffset = const_cast(a);
boffset = vec;
j = (rows >> 3);
if (j > 0) {
-
do {
aoffsets[0] = aoffset;
- for (int it = 1; it< 8; it++)
+ for (int it = 1; it < 8; it++)
aoffsets[it] = aoffsets[it-1] + lda;
aoffset += 8 * lda;
i = (cols >> 3);
if (i > 0) {
do {
- for (int it = 0; it< 8; it++) {
+ for (int it = 0; it < 8; it++) {
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
__builtin_vsx_disassemble_pair(c[it], &arr[it]);
c1[it] = c[it][0];
@@ -2264,11 +2287,14 @@ class tinyBLAS_PPC {
}
vector_permute_store_8(c1, boffset);
- vector_permute_store_8(c2, boffset+32);
- for (int it = 0; it < 4; it++)
- aoffsets[it] = aoffsets[it] + 8*lda;
+ vector_permute_store_8(c2, boffset + 32);
boffset += 64;
i--;
+ if (i > 0) {
+ for (int it = 0; it < 8; it++) {
+ aoffsets[it] = aoffsets[it] + 8;
+ }
+ }
} while(i > 0);
}
if (cols & 4) {
@@ -2295,9 +2321,9 @@ class tinyBLAS_PPC {
c2[it] = c[it][1];
}
vector_permute_store_4(c1, boffset);
- vector_permute_store_4(c2, boffset+16);
+ vector_permute_store_4(c2, boffset + 16);
for (int it = 0; it < 4; it++)
- aoffsets[it] += 8*lda;
+ aoffsets[it] += 8 * lda;
boffset += 32;
i--;
} while(i > 0);
@@ -2325,15 +2351,15 @@ class tinyBLAS_PPC {
vec_t vec_A[4], vec_B[4], vec_C[4];
acc_t acc_0;
__builtin_mma_xxsetaccz(&acc_0);
- for (int l = 0; l < k; l+=4) {
- packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
- packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+ for (int l = 0; l < k; l += 4) {
+ packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
+ packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
}
- SAVE_ACC(&acc_0, ii, jj);
+ save_acc(&acc_0, ii, jj);
}
void KERNEL_4x8(int64_t ii, int64_t jj) {
@@ -2341,9 +2367,9 @@ class tinyBLAS_PPC {
acc_t acc_0, acc_1;
__builtin_mma_xxsetaccz(&acc_0);
__builtin_mma_xxsetaccz(&acc_1);
- for (int64_t l = 0; l < k; l+=4) {
- packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
- packTranspose(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
+ for (int64_t l = 0; l < k; l += 4) {
+ packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
+ packTranspose(B + (jj * ldb) + l, ldb, 8, 4, (float *)vec_B);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
@@ -2353,8 +2379,8 @@ class tinyBLAS_PPC {
__builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);
__builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);
}
- SAVE_ACC(&acc_0, ii, jj);
- SAVE_ACC(&acc_1, ii, jj+4);
+ save_acc(&acc_0, ii, jj);
+ save_acc(&acc_1, ii, jj + 4);
}
void KERNEL_8x4(int64_t ii, int64_t jj) {
@@ -2362,9 +2388,9 @@ class tinyBLAS_PPC {
acc_t acc_0, acc_1;
__builtin_mma_xxsetaccz(&acc_0);
__builtin_mma_xxsetaccz(&acc_1);
- for (int64_t l = 0; l < k; l+=4) {
- packTranspose(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
- packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+ for (int64_t l = 0; l < k; l += 4) {
+ packTranspose(A + (ii * lda) + l, lda, 8, 4, (float *)vec_A);
+ packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
@@ -2374,8 +2400,8 @@ class tinyBLAS_PPC {
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);
}
- SAVE_ACC(&acc_0, ii, jj);
- SAVE_ACC(&acc_1, ii+4, jj);
+ save_acc(&acc_0, ii, jj);
+ save_acc(&acc_1, ii + 4, jj);
}
void KERNEL_8x8(int64_t ii, int64_t jj) {
@@ -2386,19 +2412,96 @@ class tinyBLAS_PPC {
__builtin_mma_xxsetaccz(&acc_2);
__builtin_mma_xxsetaccz(&acc_3);
for (int l = 0; l < k; l+=8) {
- packTranspose(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
- packTranspose(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
+ packTranspose(A + (ii * lda) + l, lda, 8, 8, (float *)vec_A);
+ packTranspose(B + (jj * ldb) + l, ldb, 8, 8, (float *)vec_B);
for(int x = 0; x < 16; x+=2) {
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
- __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
- __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x+1], vec_B[x]);
- __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x+1], vec_B[x+1]);
+ __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x + 1]);
+ __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x + 1], vec_B[x]);
+ __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x + 1], vec_B[x + 1]);
+ }
+ }
+ save_acc(&acc_0, ii, jj);
+ save_acc(&acc_1, ii, jj + 4);
+ save_acc(&acc_2, ii + 4, jj);
+ save_acc(&acc_3, ii + 4, jj + 4);
+ }
+
+    inline void MMA_16x8(vec_t * vec_A0, vec_t * vec_A1, vec_t * vec_B, acc_t * acc) { // feeds 16 vec_t from each of vec_A0, vec_A1 and vec_B into the eight accumulators acc[0..7]
+        for (int p = 0; p < 16; p += 2) {
+            __builtin_mma_xvf32gerpp(acc + 0, vec_A0[p],     vec_B[p]);
+            __builtin_mma_xvf32gerpp(acc + 1, vec_A0[p],     vec_B[p + 1]);
+            __builtin_mma_xvf32gerpp(acc + 2, vec_A0[p + 1], vec_B[p]);
+            __builtin_mma_xvf32gerpp(acc + 3, vec_A0[p + 1], vec_B[p + 1]);
+            __builtin_mma_xvf32gerpp(acc + 4, vec_A1[p],     vec_B[p]);
+            __builtin_mma_xvf32gerpp(acc + 5, vec_A1[p],     vec_B[p + 1]);
+            __builtin_mma_xvf32gerpp(acc + 6, vec_A1[p + 1], vec_B[p]);
+            __builtin_mma_xvf32gerpp(acc + 7, vec_A1[p + 1], vec_B[p + 1]);
+        }
+    }
+
+ void KERNEL(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, vec_t * vec_A, vec_t * vec_B, int64_t kk) { // multiplies the packed panels vec_A and vec_B in 16x8 output tiles and stores (kk == 0) or accumulates (kk != 0) the result into C starting at (ii, jj)
+ for (int64_t i = 0; i < mc; i += 16) {
+ int A_base_addr = (mc / 8) * (i / 8) * 16; // first vec_t of row-block i/8; 16 vec_t per 8x8 block — presumably the layout packTranspose produces, confirm
+ for (int64_t j = 0; j < nc; j += 8) {
+ int B_base_addr = (nc / 8) * (j / 8) * 16;
+ acc_t acc[8];
+ // eight 4x4 accumulators together cover one 16x8 tile of C (see the save offsets below)
+ for (int x = 0; x < 8; x++)
+ __builtin_mma_xxsetaccz(&acc[x]);
+ for (int64_t l = 0; l < kc; l += 8) {
+ int A0_block_idx = A_base_addr + (l / 8) * 16;
+ int A1_block_idx = A0_block_idx + (mc / 8) * 16; // the next row-block: rows i+8 .. i+15
+ int B_block_idx = B_base_addr + (l / 8) * 16;
+ vec_t* A0_block = &vec_A[A0_block_idx];
+ vec_t* A1_block = &vec_A[A1_block_idx];
+ vec_t* B_block = &vec_B[B_block_idx];
+ MMA_16x8(A0_block, A1_block, B_block, acc);
+ }
+ if (kk == 0) { // first k-panel overwrites C, later panels add to it
+ save_acc(&acc[0], ii + i, jj + j);
+ save_acc(&acc[1], ii + i, jj + j + 4);
+ save_acc(&acc[2], ii + i + 4, jj + j);
+ save_acc(&acc[3], ii + i + 4, jj + j + 4);
+ save_acc(&acc[4], ii + i + 8, jj + j);
+ save_acc(&acc[5], ii + i + 8, jj + j + 4);
+ save_acc(&acc[6], ii + i + 12, jj + j);
+ save_acc(&acc[7], ii + i + 12, jj + j + 4);
+ } else {
+ add_save_acc(&acc[0], ii + i, jj + j);
+ add_save_acc(&acc[1], ii + i, jj + j + 4);
+ add_save_acc(&acc[2], ii + i + 4, jj + j);
+ add_save_acc(&acc[3], ii + i + 4, jj + j + 4);
+ add_save_acc(&acc[4], ii + i + 8, jj + j);
+ add_save_acc(&acc[5], ii + i + 8, jj + j + 4);
+ add_save_acc(&acc[6], ii + i + 12, jj + j);
+ add_save_acc(&acc[7], ii + i + 12, jj + j + 4);
+ }
+ }
+ }
+ }
+
+ void matmul_tiled(int64_t m , int64_t n, int64_t mc, int64_t nc, int64_t kc) { // splits C into mc x nc tiles, gives this thread a contiguous range of them, and computes each tile one kc-wide k-panel at a time
+ int64_t ytiles = m / mc;
+ int64_t xtiles = n / nc;
+ int64_t tiles = xtiles * ytiles;
+ int64_t duty = (tiles + nth - 1) / nth; // tiles per thread, rounded up
+ int64_t start = duty * ith;
+ int64_t end = start + duty;
+ if (end > tiles) {
+ end = tiles;
+ }
+ for (int64_t job = start; job < end; ++job) {
+ int64_t ii = (job / xtiles) * mc;
+ int64_t jj = (job % xtiles) * nc;
+ for (int64_t kk = 0; kk < k; kk += kc) {
+ vec_t A_pack[kc * mc / 4]; // NOTE(review): runtime-sized stack array (compiler extension in C++); 16384 vec_t when called from matmul with 256/256 — confirm stack headroom
+ vec_t B_pack[kc * nc / 4];
+ packTranspose(A + (ii * lda) + kk, lda, kc, mc, (float *)A_pack); // NOTE(review): the smaller kernels pass (rows of A/B, columns along k); here kc is passed as rows — equivalent only while kc == mc == nc, confirm
+ packTranspose(B + (jj * ldb) + kk, ldb, kc, nc, (float *)B_pack);
+ KERNEL(ii, jj, mc, nc, kc, A_pack, B_pack, kk);
}
}
- SAVE_ACC(&acc_0, ii, jj);
- SAVE_ACC(&acc_1, ii, jj+4);
- SAVE_ACC(&acc_2, ii+4, jj);
- SAVE_ACC(&acc_3, ii+4, jj+4);
}
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
@@ -2406,35 +2509,35 @@ class tinyBLAS_PPC {
int n_rem = MIN(n - n0, 8);
int mc = 0, nc = 0;
if (m_rem >= 8 && n_rem >= 8) {
- mc = 8;
- nc = 8;
- gemm<8, 8>(m0, m, n0, n);
+ mc = 8;
+ nc = 8;
+ gemm<8, 8>(m0, m, n0, n);
} else if (m_rem >= 4 && n_rem >= 8) {
- mc = 4;
- nc = 8;
- gemm<4, 8>(m0, m, n0, n);
+ mc = 4;
+ nc = 8;
+ gemm<4, 8>(m0, m, n0, n);
} else if (m_rem >= 8 && n_rem >= 4) {
- mc = 8;
- nc = 4;
- gemm<8, 4>(m0, m, n0, n);
+ mc = 8;
+ nc = 4;
+ gemm<8, 4>(m0, m, n0, n);
} else if (m_rem >= 4 && n_rem >= 4) {
- mc = 4;
- nc = 4;
- gemm<4, 4>(m0, m, n0, n);
+ mc = 4;
+ nc = 4;
+ gemm<4, 4>(m0, m, n0, n);
} else {
mc = (m_rem >= 4) ? 4 : m_rem;
nc = (n_rem >= 4) ? 4 : n_rem;
if (mc == 0 || nc == 0)
- return;
+ return;
gemm_small(m0, m, n0, n, mc, nc);
}
int64_t mp = m0 + ((m - m0) / mc) * mc;
int64_t np = n0 + ((n - n0) / nc) * nc;
mnpack(mp, m, n0, np);
mnpack(m0, m, np, n);
- }
+ }
- void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
+ void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
int64_t ytiles = (m - m0) / RM;
int64_t xtiles = (n - n0) / RN;
int64_t tiles = xtiles * ytiles;
@@ -2449,30 +2552,30 @@ class tinyBLAS_PPC {
vec_t vec_C[4];
acc_t acc_0;
__builtin_mma_xxsetaccz(&acc_0);
- vec_t vec_A[4] {0}, vec_B[4] = {0};
- for (int l=0; l(A+(ii)*lda+l);
- packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
+ float * a = const_cast(A + (ii) * lda + l);
+ packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
vec_A[0] = (vec_t)vec_xl(0,a);
- vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
- vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
- vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
+ vec_A[1] = (vec_t)vec_splats(*((float *)&vec_A+1));
+ vec_A[2] = (vec_t)vec_splats(*((float *)&vec_A+2));
+ vec_A[3] = (vec_t)vec_splats(*((float *)&vec_A+3));
} else if (RN == 1) {
- packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
- float* b = const_cast(B+(jj)*ldb+l);
+ packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
+ float * b = const_cast(B + (jj) * ldb + l);
vec_B[0] = (vec_t)vec_xl(0,b);
- vec_B[1] = (vec_t)vec_splats(*((float*)&vec_B+1));
- vec_B[2] = (vec_t)vec_splats(*((float*)&vec_B+2));
- vec_B[3] = (vec_t)vec_splats(*((float*)&vec_B+3));
+ vec_B[1] = (vec_t)vec_splats(*((float *)&vec_B+1));
+ vec_B[2] = (vec_t)vec_splats(*((float *)&vec_B+2));
+ vec_B[3] = (vec_t)vec_splats(*((float *)&vec_B+3));
} else {
- packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
- packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
+ packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
+ packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
}
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
@@ -2482,12 +2585,27 @@ class tinyBLAS_PPC {
__builtin_mma_disassemble_acc(vec_C, &acc_0);
for (int I = 0; I < RM; I++) {
for (int J = 0; J < RN; J++) {
- *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
+ *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
}
}
}
}
+ template<int RM, int RN>
+ inline void kernel(int64_t ii, int64_t jj) { // compile-time dispatch to the KERNEL_<RM>x<RN> routine for the tile at (ii, jj)
+ if constexpr(RM == 4 && RN == 4) {
+ KERNEL_4x4(ii, jj);
+ } else if constexpr(RM == 4 && RN == 8) {
+ KERNEL_4x8(ii, jj);
+ } else if constexpr(RM == 8 && RN == 4) {
+ KERNEL_8x4(ii, jj);
+ } else if constexpr(RM == 8 && RN == 8) {
+ KERNEL_8x8(ii, jj);
+ } else {
+ static_assert(false, "RN/RM values not supported"); // any other tile shape is rejected at compile time
+ }
+ }
+
template
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
int64_t ytiles = (m - m0) / RM;
@@ -2496,27 +2614,18 @@ class tinyBLAS_PPC {
int64_t duty = (tiles + nth - 1) / nth;
int64_t start = duty * ith;
int64_t end = start + duty;
- if (RM == 4 && RN == 4) {
- kernel = &tinyBLAS_PPC::KERNEL_4x4;
- } else if (RM == 4 && RN == 8) {
- kernel = &tinyBLAS_PPC::KERNEL_4x8;
- } else if (RM == 8 && RN == 4) {
- kernel = &tinyBLAS_PPC::KERNEL_8x4;
- } else if (RM == 8 && RN == 8) {
- kernel = &tinyBLAS_PPC::KERNEL_8x8;
- }
if (end > tiles)
end = tiles;
for (int64_t job = start; job < end; ++job) {
int64_t ii = m0 + job / xtiles * RM;
int64_t jj = n0 + job % xtiles * RN;
- (this->*kernel)(ii, jj);
+ kernel(ii, jj);
}
}
- const float *const A;
- const float *const B;
- float *C;
+ const float * const A;
+ const float * const B;
+ float * C;
const int64_t k;
const int64_t lda;
const int64_t ldb;
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index b72a2556a5..8c1f794885 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -7207,6 +7207,148 @@ void ggml_compute_forward_conv_2d(
ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
}
+// ggml_compute_forward_conv_3d
+
+static void ggml_compute_forward_conv_3d_impl(const ggml_compute_params * params, // 3D convolution: gather input patches into the work buffer, multiply by the kernel, scatter the result into dst
+ const ggml_tensor * kernel,
+ const ggml_tensor * src,
+ ggml_tensor * dst,
+ ggml_type kernel_type) {
+
+ GGML_ASSERT(ggml_is_contiguous(kernel));
+ GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
+ GGML_ASSERT(kernel->type == kernel_type);
+
+ const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
+
+ const int32_t s0 = dst->op_params[0]; // s*: multiplies the output coordinate (stride)
+ const int32_t s1 = dst->op_params[1];
+ const int32_t s2 = dst->op_params[2];
+ const int32_t p0 = dst->op_params[3]; // p*: subtracted from the source coordinate (padding)
+ const int32_t p1 = dst->op_params[4];
+ const int32_t p2 = dst->op_params[5];
+ const int32_t d0 = dst->op_params[6]; // d*: multiplies the kernel coordinate (dilation)
+ const int32_t d1 = dst->op_params[7];
+ const int32_t d2 = dst->op_params[8];
+ const int32_t c = dst->op_params[9]; // input channels
+ const int32_t n = dst->op_params[10]; // batch count
+ const int32_t oc = dst->op_params[11]; // output channels
+
+ const int64_t src_w = src->ne[0];
+ const int64_t src_h = src->ne[1];
+ const int64_t src_d = src->ne[2];
+ const int64_t knl_w = kernel->ne[0];
+ const int64_t knl_h = kernel->ne[1];
+ const int64_t knl_d = kernel->ne[2];
+ const int64_t dst_w = dst->ne[0];
+ const int64_t dst_h = dst->ne[1];
+ const int64_t dst_d = dst->ne[2];
+
+ const float * src_data = (float *) src->data; // NOTE(review): src is read as float without a type assertion — confirm callers guarantee F32
+ void * knl_data = kernel->data;
+ float * dst_data = (float *) dst->data;
+
+ const int64_t knl_n_per_channel = knl_w * knl_h * knl_d;
+ const int64_t knl_n_total = knl_n_per_channel * c; // elements in one gathered patch row
+ const int64_t patch_total = n * dst_w * dst_h * dst_d; // one patch per output position per batch item
+
+ const int64_t space_per_patch = knl_n_total * traits->type_size + oc * sizeof(float); // one patch row plus its oc float outputs
+ const int64_t batch_size = params->wsize / space_per_patch;
+ const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size; // rounded down to a multiple of 8 when larger than 8
+
+ GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1); // checked before the division below: a too-small work buffer gives patches_per_batch == 0
+ const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
+
+ void * tmp = params->wdata;
+
+ for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
+ const int64_t patch_start_batch = batch_i * patches_per_batch;
+ const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch, patch_total);
+ const int64_t patch_n_in_batch = patch_end_batch - patch_start_batch;
+
+ const int64_t patch_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth; // each thread gathers a contiguous slice of this batch
+ const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread;
+ const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch);
+
+ for (int64_t p = patch_start; p < patch_end; ++p) {
+ const int64_t p_in_batch = p % (dst_w * dst_h * dst_d); // decompose p into (batch_idx, dst_z, dst_y, dst_x)
+ const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
+ const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
+ const int64_t dst_z = p_in_batch / (dst_w * dst_h);
+ const int64_t dst_y = p_in_depth / dst_w;
+ const int64_t dst_x = p_in_depth % dst_w;
+
+ char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size;
+
+ for (int64_t ic = 0; ic < c; ++ic) {
+ for (int64_t kz = 0; kz < knl_d; ++kz) {
+ for (int64_t ky = 0; ky < knl_h; ++ky) {
+ for (int64_t kx = 0; kx < knl_w; ++kx) {
+ const int64_t sz = dst_z * s2 + kz * d2 - p2;
+ const int64_t sy = dst_y * s1 + ky * d1 - p1;
+ const int64_t sx = dst_x * s0 + kx * d0 - p0;
+
+ int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx;
+
+ float src_val;
+ if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
+ src_val = 0.0f; // positions outside the source contribute zero
+ } else {
+ const int64_t cn_idx = batch_idx * c + ic;
+ const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]);
+ src_val = *src_ptr;
+ }
+
+ char * element_ptr = dst_row + dst_idx * traits->type_size;
+ if (kernel_type == GGML_TYPE_F32) {
+ *(float *)element_ptr = src_val;
+ } else if (kernel_type == GGML_TYPE_F16) {
+ *(ggml_fp16_t *)element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
+ }
+ }
+ }
+ }
+ }
+ }
+
+ ggml_barrier(params->threadpool); // wait until every thread has written its patch rows
+
+ float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size); // output area follows the patch rows in the work buffer
+ ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output);
+
+ ggml_barrier(params->threadpool); // wait for the matrix product before reading gemm_output
+
+ const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
+ const int64_t permute_start = params->ith * permute_per_thread;
+ const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n_in_batch);
+
+ for (int64_t i = permute_start; i < permute_end; ++i) {
+ const int64_t p = patch_start_batch + i;
+ const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
+ const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
+ const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
+ const int64_t dst_z = p_in_batch / (dst_w * dst_h);
+ const int64_t dst_y = p_in_depth / dst_w;
+ const int64_t dst_x = p_in_depth % dst_w;
+
+ for (int64_t ioc = 0; ioc < oc; ++ioc) {
+ const float value = gemm_output[i * oc + ioc]; // row i of gemm_output holds the oc outputs of patch i
+ const int64_t ocn_idx = batch_idx * oc + ioc;
+ float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]);
+ *dst_ptr = value;
+ }
+ }
+ }
+}
+
+void ggml_compute_forward_conv_3d( // entry point: forwards dst->src[0] as the kernel and dst->src[1] as the input to the implementation
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_tensor * knl   = dst->src[0];
+    const ggml_tensor * input = dst->src[1];
+    ggml_compute_forward_conv_3d_impl(params, knl, input, dst, knl->type);
+}
+
// ggml_compute_forward_conv_transpose_2d
void ggml_compute_forward_conv_transpose_2d(
@@ -8861,8 +9003,7 @@ static void ggml_compute_forward_ssm_scan_f32(
GGML_ASSERT(src4->nb[0] == sizeof(float));
GGML_ASSERT(src5->nb[0] == sizeof(float));
GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
- // allows optimizing the modulo since n_group should be a power of 2
- GGML_ASSERT((ng & -ng) == ng);
+ GGML_ASSERT(nh % ng == 0);
// heads per thread
const int dh = (nh + nth - 1)/nth;
@@ -8893,6 +9034,7 @@ static void ggml_compute_forward_ssm_scan_f32(
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
const float dA = expf(dt_soft_plus * A[h]);
+ const int g = h / (nh / ng); // repeat_interleave
// dim
for (int i1 = 0; i1 < nr; ++i1) {
@@ -8915,8 +9057,8 @@ static void ggml_compute_forward_ssm_scan_f32(
// TODO: maybe unroll more?
for (int j = 0; j < 1; j++) {
GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
- GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
- GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
+ GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc);
+ GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc);
t0 = GGML_F32_VEC_MUL(t0, adA);
t1 = GGML_F32_VEC_MUL(t1, axdt);
@@ -8930,6 +9072,9 @@ static void ggml_compute_forward_ssm_scan_f32(
}
sumf = GGML_F32xt_REDUCE_ONE(sum);
+ #elif defined(__riscv_v_intrinsic)
+ // todo: RVV implementation
+ const int np = 0;
#else
const int np = (nc & ~(GGML_F32_STEP - 1));
@@ -8945,8 +9090,8 @@ static void ggml_compute_forward_ssm_scan_f32(
for (int i = 0; i < np; i += GGML_F32_STEP) {
for (int j = 0; j < GGML_F32_ARR; j++) {
ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
- ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
- az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
+ ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + g*nc);
+ az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + g*nc);
ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
@@ -8968,7 +9113,7 @@ static void ggml_compute_forward_ssm_scan_f32(
// d_state
for (int i0 = np; i0 < nc; ++i0) {
const int i = i0 + ii*nc;
- const int ig = i0 + (h & (ng - 1))*nc;
+ const int ig = i0 + g*nc;
// state = prev_state * dA + dB * x
const float state = (s0[i] * dA) + (B[ig] * x_dt);
// y = rowwise_dotprod(state, C)
@@ -8985,6 +9130,7 @@ static void ggml_compute_forward_ssm_scan_f32(
for (int h = ih0; h < ih1; ++h) {
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+ const int g = h / (nh / ng); // repeat_interleave
// dim
for (int i1 = 0; i1 < nr; ++i1) {
@@ -8999,8 +9145,8 @@ static void ggml_compute_forward_ssm_scan_f32(
// TODO: what happens when (d_state % svcntw()) != 0?
for (int64_t k = 0; k < nc; k += svcntw()) {
svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
- svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
- svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
+ svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]);
+ svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]);
svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
@@ -9020,7 +9166,7 @@ static void ggml_compute_forward_ssm_scan_f32(
// d_state
for (int i0 = 0; i0 < nc; ++i0) {
const int i = i0 + ii*nc;
- const int ig = i0 + (h & (ng - 1))*nc;
+ const int ig = i0 + g*nc;
// state = prev_state * dA + dB * x
const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
// y = rowwise_dotprod(state, C)
@@ -9881,8 +10027,8 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
int64_t h_stride_2d = head_size * head_size;
#if defined(GGML_SIMD)
- #if defined(__ARM_FEATURE_SVE)
- // scalar Route to scalar implementation //TODO: Write SVE code
+ #if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic)
+ // scalar Route to scalar implementation //TODO: Write SVE code and RVV code
for (int64_t t = 0; t < T; t++) {
int64_t t_offset = t * t_stride;
int64_t state_offset = head_size * C * (t / (T / n_seqs));
diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h
index 82ea79eaa5..d0ea83843b 100644
--- a/ggml/src/ggml-cpu/ops.h
+++ b/ggml/src/ggml-cpu/ops.h
@@ -70,6 +70,7 @@ void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * p
void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_conv_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_2d_dw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_pool_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h
index b4ad68c9fd..f71ce58079 100644
--- a/ggml/src/ggml-cpu/simd-mappings.h
+++ b/ggml/src/ggml-cpu/simd-mappings.h
@@ -18,6 +18,10 @@
#include
#endif
+#if defined(__riscv_v_intrinsic)
+#include
+#endif
+
#ifdef __cplusplus
extern "C" {
#endif
@@ -94,24 +98,15 @@ extern "C" {
}
#elif defined(__riscv) && defined(__riscv_zfhmin)
static inline float riscv_compute_fp16_to_fp32(ggml_fp16_t h) {
- float f;
- __asm__(
- "fmv.h.x %[f], %[h]\n\t"
- "fcvt.s.h %[f], %[f]"
- : [f] "=&f" (f)
- : [h] "r" (h)
- );
- return f;
+ _Float16 hf; // half-precision temporary used instead of the inline-asm conversion
+ memcpy(&hf, &h, sizeof(ggml_fp16_t)); // bit-copy h into hf, no value conversion; assumes sizeof(_Float16) == sizeof(ggml_fp16_t)
+ return hf; // implicit _Float16 -> float conversion performed by the return
}
static inline ggml_fp16_t riscv_compute_fp32_to_fp16(float f) {
ggml_fp16_t res;
- __asm__(
- "fcvt.h.s %[f], %[f]\n\t"
- "fmv.x.h %[h], %[f]"
- : [h] "=&r" (res)
- : [f] "f" (f)
- );
+ _Float16 hf = (_Float16)f; // float -> half value conversion done by the cast instead of inline asm
+ memcpy(&res, &hf, sizeof(ggml_fp16_t)); // bit-copy the half into the storage type; assumes sizeof(_Float16) == sizeof(ggml_fp16_t)
return res;
}
@@ -1170,6 +1165,36 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) {
#define GGML_F16_VEC_MUL GGML_F32x4_MUL
#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
+#elif defined(__riscv_v_intrinsic)
+
+// compatible with vlen >= 128
+
+#define GGML_SIMD
+
+// F32
+
+#define GGML_F32_STEP 16
+#define GGML_F32_EPR 4
+
+#define GGML_F32x4 vfloat32m1_t
+#define GGML_F32x4_ZERO __riscv_vfmv_v_f_f32m1(0.0f, GGML_F32_EPR)
+#define GGML_F32x4_SET1(x) __riscv_vfmv_v_f_f32m1(x, GGML_F32_EPR)
+#define GGML_F32x4_LOAD(x) __riscv_vle32_v_f32m1(x, GGML_F32_EPR)
+#define GGML_F32x4_STORE(b, v) __riscv_vse32_v_f32m1(b, v, GGML_F32_EPR)
+#define GGML_F32x4_FMA(a, b, c) __riscv_vfmacc_vv_f32m1(a, b, c, GGML_F32_EPR)
+#define GGML_F32x4_ADD(a, b) __riscv_vfadd_vv_f32m1(a, b, GGML_F32_EPR)
+#define GGML_F32x4_MUL(a, b) __riscv_vfmul_vv_f32m1(a, b, GGML_F32_EPR)
+
+#define GGML_F32_VEC GGML_F32x4
+#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1 GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
#endif
// GGML_F32_ARR / GGML_F16_ARR
diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp
index 07b377bdd8..d8ec3b81d2 100644
--- a/ggml/src/ggml-cpu/vec.cpp
+++ b/ggml/src/ggml-cpu/vec.cpp
@@ -84,6 +84,16 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
}
// reduce sum1,sum2 to sum1
GGML_F32_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8);
+ #elif defined(__riscv_v_intrinsic)
+ vfloat32m1_t vsum = __riscv_vfmv_v_f_f32m1(0.0f, 1);
+ for (int i = 0, avl; i < n; i += avl) {
+ avl = __riscv_vsetvl_e32m8(n - i);
+ vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl);
+ vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
+ vfloat32m8_t prod = __riscv_vfmul_vv_f32m8(ax, ay, avl);
+ vsum = __riscv_vfredusum_vs_f32m8_f32m1(prod, vsum, avl);
+ }
+ sumf += __riscv_vfmv_f_s_f32m1_f32(vsum);
#else
const int np = (n & ~(GGML_F32_STEP - 1));
@@ -197,7 +207,7 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G
ggml_float sumf = 0.0;
-#if defined(GGML_SIMD)
+#if defined(GGML_SIMD) && !defined(__riscv_v_intrinsic)
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
@@ -325,6 +335,15 @@ ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float
vst1q_f32(y + i, val);
sum += (ggml_float)vaddvq_f32(val);
}
+#elif defined(__riscv_v_intrinsic)
+ vfloat64m1_t vsum = __riscv_vfmv_v_f_f64m1(0, 1);
+ for (int avl; i < n; i += avl) {
+ avl = __riscv_vsetvl_e32m2(n - i);
+ vfloat32m2_t val = ggml_v_expf_m2(__riscv_vfsub_vf_f32m2(__riscv_vle32_v_f32m2(&x[i], avl), max, avl), avl);
+ __riscv_vse32_v_f32m2(&y[i], val, avl);
+ vsum = __riscv_vfwredusum_vs_f32m2_f64m1(val, vsum, avl);
+ }
+ return (ggml_float)__riscv_vfmv_f_s_f64m1_f64(vsum);
#endif
for (; i < n; ++i) {
float val = expf(x[i] - max);
diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h
index 2250d93cb0..8ccf340d47 100644
--- a/ggml/src/ggml-cpu/vec.h
+++ b/ggml/src/ggml-cpu/vec.h
@@ -119,6 +119,14 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
}
#if defined(GGML_SIMD)
+#if defined(__riscv_v_intrinsic)
+ // todo: RVV impl
+ for (int i = 0; i < n; ++i) {
+ for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
+ sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
+ }
+ }
+#else
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
@@ -149,6 +157,7 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
}
}
+#endif
#else
for (int i = 0; i < n; ++i) {
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
@@ -243,6 +252,14 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
svst1_f32(pg, y + np2, ay1);
}
+ #elif defined(__riscv_v_intrinsic)
+ for (int i = 0, avl; i < n; i += avl) {
+ avl = __riscv_vsetvl_e32m8(n - i);
+ vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl);
+ vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
+ vfloat32m8_t ny = __riscv_vfmadd_vf_f32m8(ax, v, ay, avl);
+ __riscv_vse32_v_f32m8(&y[i], ny, avl);
+ }
#else
const int np = (n & ~(GGML_F32_STEP - 1));
@@ -276,6 +293,13 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) {
#if defined(GGML_SIMD)
+#if defined(__riscv_v_intrinsic)
+ // todo: RVV impl
+ // scalar
+ for (int i = 0; i < n; ++i) {
+ y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
+ }
+#else
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
@@ -297,6 +321,7 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y,
for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
}
+#endif
#else
// scalar
for (int i = 0; i < n; ++i) {
@@ -324,6 +349,16 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int
y[i] += x[k][i]*v[k][0];
}
}
+ #elif defined(__riscv_v_intrinsic)
+ for (int i = 0, avl; i < n; i += avl) {
+ avl = __riscv_vsetvl_e32m8(n - i);
+ vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; k++) {
+ vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[k][i], avl);
+ ay = __riscv_vfmadd_vf_f32m8(ax, v[k][0], ay, avl);
+ }
+ __riscv_vse32_v_f32m8(&y[i], ay, avl);
+ }
#else
const int np = (n & ~(GGML_F32_STEP - 1));
@@ -375,6 +410,14 @@ inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, co
for (int i = 0; i < n; ++i) {
y[i] = x[i]*s + b;
}
+ #elif defined(__riscv_v_intrinsic)
+ for (int i = 0, avl; i < n; i += avl) {
+ avl = __riscv_vsetvl_e32m8(n - i);
+ vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl);
+ vfloat32m8_t vb = __riscv_vfmv_v_f_f32m8(b, avl);
+ vfloat32m8_t ny = __riscv_vfmadd_vf_f32m8(ax, s, vb, avl);
+ __riscv_vse32_v_f32m8(&y[i], ny, avl);
+ }
#else
const int np = (n & ~(GGML_F32_STEP - 1));
@@ -436,6 +479,13 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
ay1 = svmul_f32_m(pg, ay1, vx);
svst1_f32(pg, y + np, ay1);
}
+ #elif defined(__riscv_v_intrinsic)
+ for (int i = 0, avl; i < n; i += avl) {
+ avl = __riscv_vsetvl_e32m8(n - i);
+ vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
+ vfloat32m8_t ny = __riscv_vfmul_vf_f32m8(ay, v, avl);
+ __riscv_vse32_v_f32m8(&y[i], ny, avl);
+ }
#else
const int np = (n & ~(GGML_F32_STEP - 1));
@@ -467,6 +517,13 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
#if defined(GGML_SIMD)
+#if defined(__riscv_v_intrinsic)
+ // todo: RVV impl
+ // scalar
+ for (int i = 0; i < n; ++i) {
+ y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
+ }
+#else
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
@@ -486,6 +543,7 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
}
+#endif
#else
// scalar
for (int i = 0; i < n; ++i) {
@@ -928,7 +986,51 @@ inline static __m128 ggml_v_silu(__m128 x) {
return _mm_div_ps(x, one_plus_exp_neg_x);
}
-#endif // __ARM_NEON / __AVX2__ / __SSE2__
+#elif defined(__riscv_v_intrinsic)
+
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+inline static vfloat32m2_t ggml_v_expf_m2(vfloat32m2_t x, int vl) { // element-wise expf over the first vl lanes of an LMUL=2 f32 vector
+ const vfloat32m2_t r = __riscv_vfmv_v_f_f32m2(0x1.8p23f, vl); // splat of 1.5 * 2^23, the rounding "shifter" constant
+#ifdef __riscv_xtheadvector
+ // workaround for compiler bug (gcc 14.3.0: Error: unrecognized opcode `th.vmv1r.v v2,v4')
+ vfloat32m2_t z = __riscv_vfadd_vf_f32m2(r, 0.0f, vl); // copy of r made with an add of 0.0f
+ z = __riscv_vfmacc_vf_f32m2(z, 0x1.715476p+0f, x, vl); // z = r + x * 0x1.715476p+0 (~1.442695, i.e. log2(e))
+#else
+ const vfloat32m2_t z = __riscv_vfmacc_vf_f32m2(r, 0x1.715476p+0f, x, vl); // z = r + x * 0x1.715476p+0 (~1.442695, i.e. log2(e))
+#endif
+ const vfloat32m2_t n = __riscv_vfsub_vv_f32m2(z, r, vl); // n = z - r; presumably x*log2(e) rounded to an integer by the shifter add - confirm
+ const vfloat32m2_t b = __riscv_vfnmsac_vf_f32m2(__riscv_vfnmsac_vf_f32m2(x, 0x1.62e4p-1f, n, vl), // b = x - n*0x1.62e4p-1 ...
+ 0x1.7f7d1cp-20f, n, vl); // ... - n*0x1.7f7d1cp-20; the two constants sum to ~ln(2)
+ const vuint32m2_t e = __riscv_vsll_vx_u32m2(__riscv_vreinterpret_v_f32m2_u32m2(z), 23, vl); // raw bits of z shifted left by 23 (the f32 exponent position)
+ const vfloat32m2_t k = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(e, 0x3f800000, vl)); // 1.0f bit pattern added to e, reinterpreted as float (the power-of-two scale)
+ const vbool16_t c = __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 126.0f, vl); // lanes with |n| > 126 take the special-case path below
+ const vfloat32m2_t u = __riscv_vfmul_vv_f32m2(b, b, vl); // u = b^2
+ const vfloat32m2_t j = __riscv_vfmacc_vv_f32m2( // j = b*c1 + ((c2 + c3*b) + (c4 + c5*b)*u) * u
+ __riscv_vfmul_vf_f32m2(b, 0x1.ffffecp-1f, vl),
+ __riscv_vfmacc_vv_f32m2(
+ __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.fffdb6p-2f, vl), 0x1.555e66p-3f, b, vl),
+ __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.573e2ep-5f, vl), 0x1.0e4020p-7f, b, vl),
+ u, vl), u, vl);
+ if (!__riscv_vcpop_m_b16(c, vl)) // fast path: no lane has |n| > 126
+ return __riscv_vfmacc_vv_f32m2(k, j, k, vl); // k + j*k
+ const vbool16_t dm = __riscv_vmfle_vf_f32m2_b16(n, 0.0f, vl); // lanes with n <= 0
+ const vuint32m2_t d = __riscv_vmerge_vxm_u32m2(__riscv_vmv_v_x_u32m2(0, vl), 0x82000000, dm, vl); // d = 0x82000000 where n <= 0, else 0
+ const vfloat32m2_t s1 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(d, 0x7f000000, vl)); // first scale factor: bits d + 0x7f000000
+ const vfloat32m2_t s2 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vsub_vv_u32m2(e, d, vl)); // second scale factor: bits e - d
+ const vfloat32m2_t r1 = __riscv_vmerge_vvm_f32m2( // per lane: c ? (s2 + s2*j) * s1 : k + k*j
+ __riscv_vfmacc_vv_f32m2(k, k, j, vl),
+ __riscv_vfmul_vv_f32m2(__riscv_vfmacc_vv_f32m2(s2, s2, j, vl), s1, vl),
+ c, vl);
+ return __riscv_vmerge_vvm_f32m2( // lanes with |n| > 192 return s1*s1 (presumably the overflow/underflow flush - confirm), others r1
+ r1, __riscv_vfmul_vv_f32m2(s1, s1, vl),
+ __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 192.0f, vl),
+ vl);
+}
+
+#endif // __ARM_NEON / __AVX2__ / __SSE2__ / __riscv_v_intrinsic
inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
for (int i = 0; i < n; ++i) {
diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt
index ea824965aa..d3dfc7807d 100644
--- a/ggml/src/ggml-cuda/CMakeLists.txt
+++ b/ggml/src/ggml-cuda/CMakeLists.txt
@@ -94,7 +94,11 @@ if (CUDAToolkit_FOUND)
# As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas)
else ()
- target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static)
+ if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "10.1")
+ target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
+ else()
+ target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static)
+ endif()
endif()
else()
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas)
diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu
index e1fbf0e136..1c76566344 100644
--- a/ggml/src/ggml-cuda/binbcast.cu
+++ b/ggml/src/ggml-cuda/binbcast.cu
@@ -1,5 +1,6 @@
#include "binbcast.cuh"
#include
+#include
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
return b;
@@ -22,13 +23,16 @@ static __device__ __forceinline__ float op_div(const float a, const float b) {
return a / b;
}
-template
+
+
+template
static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
- int ne0, int ne1, int ne2, int ne3,
- int ne10, int ne11, int ne12, int ne13,
- /*int s0, */ int s1, int s2, int s3,
- /*int s00,*/ int s01, int s02, int s03,
- /*int s10,*/ int s11, int s12, int s13) {
+ const int ne0, const int ne1, const int ne2, const int ne3,
+ const int ne10, const int ne11, const int ne12, const int ne13,
+ /*int s0, */ const int s1, const int s2, const int s3,
+ /*int s00,*/ const int s01, const int s02, const int s03,
+ /*int s10,*/ const int s11, const int s12, const int s13,
+ src1_ptrs... src1s) {
const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
@@ -46,24 +50,31 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
- const src0_t * src0_row = src0 + i_src0;
- const src1_t * src1_row = src1 + i_src1;
+ const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
dst_t * dst_row = dst + i_dst;
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) {
const int i10 = i0 % ne10;
- dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+
+ float result = src0_row ? (float) src0_row[i0] : 0.0f;
+ if constexpr (sizeof...(src1_ptrs) > 0) {
+ result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
+ } else {
+ result = bin_op(result, (float)src1[i_src1 + i10]);
+ }
+
+ dst_row[i0] = (dst_t) result;
}
}
-template
-static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
- int ne0, int ne1, int ne2, int ne3,
- int ne10, int ne11, int ne12, int ne13,
- /*int s0, */ int s1, int s2, int s3,
- /*int s00,*/ int s01, int s02, int s03,
- /*int s10,*/ int s11, int s12, int s13) {
-
+template
+static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
+ const int ne0, const int ne1, const int ne2,const int ne3,
+ const int ne10, const int ne11, const int ne12, const int ne13,
+ /*int s0, */ const int s1, const int s2, const int s3,
+ /*int s00,*/ const int s01, const int s02, const int s03,
+ /*int s10,*/ const int s11, const int s12, const int s13,
+ src1_ptrs ... src1s) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
const int i3 = i/(ne2*ne1*ne0);
@@ -83,12 +94,190 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
- const src0_t * src0_row = src0 + i_src0;
- const src1_t * src1_row = src1 + i_src1;
+ const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
dst_t * dst_row = dst + i_dst;
const int i10 = i0 % ne10;
- dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+
+ float result = src0_row ? (float) src0_row[i0] : 0.0f;
+ if constexpr (sizeof...(src1_ptrs) > 0) {
+ result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
+ } else {
+ result = bin_op(result, (float)src1[i_src1 + i10]);
+ }
+
+ dst_row[i0] = (dst_t) result;
+}
+
+template
+static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, // host-side launcher: collapses dims, derives element strides, picks the 3D or 1D-unravel kernel
+ const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
+ cudaStream_t stream, std::index_sequence) {
+ GGML_TENSOR_BINARY_OP_LOCALS // presumably declares the ne*/nb* extent and stride locals used below - confirm against the macro
+
+ int nr0 = ne10 / ne0; // per-dimension ratio of src1 extent to dst extent
+ int nr1 = ne11 / ne1;
+ int nr2 = ne12 / ne2;
+ int nr3 = ne13 / ne3;
+
+ int nr[4] = { nr0, nr1, nr2, nr3 };
+
+ int64_t cne[] = { ne0, ne1, ne2, ne3 }; // collapse dimensions until first broadcast dimension (working copies of extents)
+ int64_t cne0[] = { ne00, ne01, ne02, ne03 };
+ int64_t cne1[] = { ne10, ne11, ne12, ne13 };
+
+ size_t cnb[] = { nb0, nb1, nb2, nb3 }; // working copies of the byte strides
+ size_t cnb0[] = { nb00, nb01, nb02, nb03 };
+ size_t cnb1[] = { nb10, nb11, nb12, nb13 };
+
+ auto collapse = [](int64_t cne[]) { // merge dim 1 into dim 0 and shift the higher dims down by one
+ cne[0] *= cne[1];
+ cne[1] = cne[2];
+ cne[2] = cne[3];
+ cne[3] = 1;
+ };
+
+ auto collapse_nb = [](size_t cnb[], const int64_t cne[]) { // scale each stride by its own extent; called before collapse() while cne is still un-merged
+ cnb[1] *= cne[1];
+ cnb[2] *= cne[2];
+ cnb[3] *= cne[3];
+ };
+
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { // collapsing is only attempted when all three tensors are contiguous
+ for (int i = 0; i < 4; i++) {
+ if (nr[i] != 1) { // stop at the first dimension whose ratio is not 1
+ break;
+ }
+ if (i > 0) { // one merge step per leading dimension i > 0 that passed the check
+ collapse_nb(cnb, cne);
+ collapse_nb(cnb0, cne0);
+ collapse_nb(cnb1, cne1);
+ collapse(cne);
+ collapse(cne0);
+ collapse(cne1);
+ }
+ }
+ }
+
+ {
+ int64_t ne0 = cne[0]; // from here on the (possibly collapsed) values shadow the outer ne*/nb* locals
+ int64_t ne1 = cne[1];
+ int64_t ne2 = cne[2];
+ int64_t ne3 = cne[3];
+
+ //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00);
+ //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01);
+ //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
+ //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
+
+ int64_t ne10 = cne1[0];
+ int64_t ne11 = cne1[1];
+ int64_t ne12 = cne1[2];
+ int64_t ne13 = cne1[3];
+
+ size_t nb0 = cnb[0];
+ size_t nb1 = cnb[1];
+ size_t nb2 = cnb[2];
+ size_t nb3 = cnb[3];
+
+ size_t nb00 = cnb0[0];
+ size_t nb01 = cnb0[1];
+ size_t nb02 = cnb0[2];
+ size_t nb03 = cnb0[3];
+
+ size_t nb10 = cnb1[0];
+ size_t nb11 = cnb1[1];
+ size_t nb12 = cnb1[2];
+ size_t nb13 = cnb1[3];
+
+ size_t s0 = nb0 / sizeof(dst_t); // byte strides converted to element strides (divisibility is asserted below)
+ size_t s1 = nb1 / sizeof(dst_t);
+ size_t s2 = nb2 / sizeof(dst_t);
+ size_t s3 = nb3 / sizeof(dst_t);
+
+ size_t s10 = nb10 / sizeof(src1_t);
+ size_t s11 = nb11 / sizeof(src1_t);
+ size_t s12 = nb12 / sizeof(src1_t);
+ size_t s13 = nb13 / sizeof(src1_t);
+
+ size_t s00 = nb00 / sizeof(src0_t);
+ size_t s01 = nb01 / sizeof(src0_t);
+ size_t s02 = nb02 / sizeof(src0_t);
+ size_t s03 = nb03 / sizeof(src0_t);
+
+ GGML_ASSERT(nb0 % sizeof(dst_t) == 0); // every byte stride must be a whole number of elements
+ GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
+ GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
+ GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
+
+ GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
+ GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
+ GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
+ GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
+
+ GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
+ GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
+ GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
+ GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
+
+ GGML_ASSERT(s0 == 1); // innermost stride must be exactly one element; s0/s00/s10 are not passed to the kernels
+ GGML_ASSERT(s00 == 1);
+ GGML_ASSERT(s10 == 1);
+
+ const int block_size = 128;
+
+ int64_t hne0 = std::max(ne0 / 2LL, 1LL); // half of ne0, at least 1; used to size the x dimension of block and grid
+
+ dim3 block_dims;
+ block_dims.x = std::min(hne0, block_size);
+ block_dims.y = std::min(ne1, block_size / block_dims.x);
+ block_dims.z = std::min(std::min(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U); // z covers ne2*ne3 jointly and is capped at 64
+
+ dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x, // ceil-divide each extent by its block dimension
+ (ne1 + block_dims.y - 1) / block_dims.y,
+ (ne2 * ne3 + block_dims.z - 1) / block_dims.z);
+
+ if (block_nums.z > 65535) { // this is the maximum number of blocks in z dimension, fallback to 1D grid kernel
+ int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size; // NOTE(review): int64_t product narrowed to int - confirm it cannot exceed INT_MAX
+ if constexpr (sizeof...(I) > 0) { // extra operands: also pass dst->src[I + 1]->data for every index in the pack
+ k_bin_bcast_unravel
+ <<>>(src0_dd, src1_dd, dst_dd,
+ ne0, ne1, ne2, ne3,
+ ne10, ne11, ne12, ne13,
+ /* s0, */ s1, s2, s3,
+ /* s00,*/ s01, s02, s03,
+ /* s10,*/ s11, s12,s13,
+ (const src1_t *) dst->src[I + 1]->data...); // NOTE(review): cast to src1_t without a type check here - presumably guaranteed by the caller
+ } else {
+ k_bin_bcast_unravel
+ <<>>(src0_dd, src1_dd, dst_dd,
+ ne0, ne1, ne2, ne3,
+ ne10, ne11, ne12, ne13,
+ /* s0, */ s1, s2, s3,
+ /* s00,*/ s01, s02, s03,
+ /* s10,*/ s11, s12,s13);
+ }
+ } else {
+ if constexpr (sizeof...(I) > 0) { // same split as above, for the 3D-grid kernel
+ k_bin_bcast
+ <<>>(src0_dd, src1_dd, dst_dd,
+ ne0, ne1, ne2, ne3,
+ ne10, ne11, ne12, ne13,
+ /* s0, */ s1, s2, s3,
+ /* s00,*/ s01, s02, s03,
+ /* s10,*/ s11, s12,s13,
+ (const src1_t *) dst->src[I + 1]->data...); // NOTE(review): cast to src1_t without a type check here - presumably guaranteed by the caller
+ } else {
+ k_bin_bcast
+ <<>>(src0_dd, src1_dd, dst_dd,
+ ne0, ne1, ne2, ne3,
+ ne10, ne11, ne12, ne13,
+ /* s0, */ s1, s2, s3,
+ /* s00,*/ s01, s02, s03,
+ /* s10,*/ s11, s12,s13);
+ }
+ }
+ }
+}
template
@@ -120,160 +309,14 @@ static __global__ void k_repeat_back(
dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
}
-template
+template
struct bin_bcast_cuda {
template
void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
cudaStream_t stream) {
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- int nr0 = ne10/ne0;
- int nr1 = ne11/ne1;
- int nr2 = ne12/ne2;
- int nr3 = ne13/ne3;
-
- int nr[4] = { nr0, nr1, nr2, nr3 };
-
- // collapse dimensions until first broadcast dimension
- int64_t cne[] = {ne0, ne1, ne2, ne3};
- int64_t cne0[] = {ne00, ne01, ne02, ne03};
- int64_t cne1[] = {ne10, ne11, ne12, ne13};
-
- size_t cnb[] = {nb0, nb1, nb2, nb3};
- size_t cnb0[] = {nb00, nb01, nb02, nb03};
- size_t cnb1[] = {nb10, nb11, nb12, nb13};
-
- auto collapse = [](int64_t cne[]) {
- cne[0] *= cne[1];
- cne[1] = cne[2];
- cne[2] = cne[3];
- cne[3] = 1;
- };
-
- auto collapse_nb = [](size_t cnb[], const int64_t cne[]) {
- cnb[1] *= cne[1];
- cnb[2] *= cne[2];
- cnb[3] *= cne[3];
- };
-
- if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
- for (int i = 0; i < 4; i++) {
- if (nr[i] != 1) {
- break;
- }
- if (i > 0) {
- collapse_nb(cnb, cne);
- collapse_nb(cnb0, cne0);
- collapse_nb(cnb1, cne1);
- collapse(cne);
- collapse(cne0);
- collapse(cne1);
- }
- }
- }
-
- {
- int64_t ne0 = cne[0];
- int64_t ne1 = cne[1];
- int64_t ne2 = cne[2];
- int64_t ne3 = cne[3];
-
- //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00);
- //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01);
- //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
- //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
-
- int64_t ne10 = cne1[0];
- int64_t ne11 = cne1[1];
- int64_t ne12 = cne1[2];
- int64_t ne13 = cne1[3];
-
- size_t nb0 = cnb[0];
- size_t nb1 = cnb[1];
- size_t nb2 = cnb[2];
- size_t nb3 = cnb[3];
-
- size_t nb00 = cnb0[0];
- size_t nb01 = cnb0[1];
- size_t nb02 = cnb0[2];
- size_t nb03 = cnb0[3];
-
- size_t nb10 = cnb1[0];
- size_t nb11 = cnb1[1];
- size_t nb12 = cnb1[2];
- size_t nb13 = cnb1[3];
-
- size_t s0 = nb0 / sizeof(dst_t);
- size_t s1 = nb1 / sizeof(dst_t);
- size_t s2 = nb2 / sizeof(dst_t);
- size_t s3 = nb3 / sizeof(dst_t);
-
- size_t s10 = nb10 / sizeof(src1_t);
- size_t s11 = nb11 / sizeof(src1_t);
- size_t s12 = nb12 / sizeof(src1_t);
- size_t s13 = nb13 / sizeof(src1_t);
-
- size_t s00 = nb00 / sizeof(src0_t);
- size_t s01 = nb01 / sizeof(src0_t);
- size_t s02 = nb02 / sizeof(src0_t);
- size_t s03 = nb03 / sizeof(src0_t);
-
- GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
- GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
- GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
- GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
-
- GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
- GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
- GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
- GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
-
- GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
- GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
- GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
- GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
-
- GGML_ASSERT(s0 == 1);
- GGML_ASSERT(s00 == 1);
- GGML_ASSERT(s10 == 1);
-
- const int block_size = 128;
-
- int64_t hne0 = std::max(ne0/2LL, 1LL);
-
- dim3 block_dims;
- block_dims.x = std::min(hne0, block_size);
- block_dims.y = std::min(ne1, block_size / block_dims.x);
- block_dims.z = std::min(std::min(ne2*ne3, block_size / block_dims.x / block_dims.y), 64U);
-
- dim3 block_nums(
- (hne0 + block_dims.x - 1) / block_dims.x,
- (ne1 + block_dims.y - 1) / block_dims.y,
- (ne2*ne3 + block_dims.z - 1) / block_dims.z
- );
-
- if (block_nums.z > 65535) {
- // this is the maximum number of blocks in z dimension, fallback to 1D grid kernel
- int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
- k_bin_bcast_unravel<<>>(
- src0_dd, src1_dd, dst_dd,
- ne0, ne1, ne2, ne3,
- ne10, ne11, ne12, ne13,
- /* s0, */ s1, s2, s3,
- /* s00, */ s01, s02, s03,
- /* s10, */ s11, s12, s13);
- } else {
- k_bin_bcast<<>>(
- src0_dd, src1_dd, dst_dd,
- ne0, ne1, ne2, ne3,
- ne10, ne11, ne12, ne13,
- /* s0, */ s1, s2, s3,
- /* s00, */ s01, s02, s03,
- /* s10, */ s11, s12, s13);
- }
- }
+ launch_bin_bcast_pack(
+ src0, src1, dst, src0_dd, src1_dd, dst_dd, stream, std::make_index_sequence{});
}
};
@@ -312,7 +355,7 @@ static void ggml_cuda_op_bin_bcast(
}
void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
- ggml_cuda_op_bin_bcast>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
+ ggml_cuda_op_bin_bcast>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
}
void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -331,6 +374,68 @@ void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_bin_bcast>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}
+template
+static void ggml_cuda_op_fused_binbcast_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { // picks the typed launch_bin_bcast_pack call from the tensor types
+ cudaStream_t stream = ctx.stream();
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // NOTE(review): src1->type is not checked here, src1 is read as float - confirm callers guarantee F32
+ launch_bin_bcast_pack(src0, src1, dst,
+ (const float *) src0->data, (const float *) src1->data, (float *) dst->data,
+ stream, std::make_index_sequence{});
+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+ launch_bin_bcast_pack(src0, src1, dst,
+ (const half *) src0->data, (const half *) src1->data, (half *) dst->data,
+ stream, std::make_index_sequence{});
+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+ launch_bin_bcast_pack(src0, src1, dst,
+ (const half *) src0->data, (const float *) src1->data, (half *) dst->data,
+ stream, std::make_index_sequence{});
+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { // NOTE(review): src1->type is not checked here either, src1 is read as float
+ launch_bin_bcast_pack(src0, src1, dst,
+ (const half *) src0->data, (const float *) src1->data, (float *) dst->data,
+ stream, std::make_index_sequence{});
+ } else { // no typed launch matches: report the three types, then abort
+ fprintf(stderr,
+ "%s: unsupported types for fusion: dst: %s, src0: %s, src1: %s\n",
+ __func__, ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+ GGML_ABORT("fatal error");
+ }
+}
+
+
+void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) { // maps the runtime n_fuse onto one fused-impl call per supported value
+ GGML_ASSERT(2 <= n_fuse && n_fuse <= 8); // only 2..8 have a case below
+
+ switch (n_fuse) {
+ case 2:
+ ggml_cuda_op_fused_binbcast_impl(ctx, dst);
+ break;
+ case 3:
+ ggml_cuda_op_fused_binbcast_impl(ctx, dst);
+ break;
+ case 4:
+ ggml_cuda_op_fused_binbcast_impl(ctx, dst);
+ break;
+ case 5:
+ ggml_cuda_op_fused_binbcast_impl(ctx, dst);
+ break;
+ case 6:
+ ggml_cuda_op_fused_binbcast_impl(ctx, dst);
+ break;
+ case 7:
+ ggml_cuda_op_fused_binbcast_impl(ctx, dst);
+ break;
+ case 8:
+ ggml_cuda_op_fused_binbcast_impl(ctx, dst);
+ break;
+ default: // should be unreachable: the range assert at the top rejects every other value
+ GGML_ASSERT(false && "Unsupported n_fuse value");
+ }
+}
+
void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
diff --git a/ggml/src/ggml-cuda/binbcast.cuh b/ggml/src/ggml-cuda/binbcast.cuh
index 3ac1c9b03f..62bc950111 100644
--- a/ggml/src/ggml-cuda/binbcast.cuh
+++ b/ggml/src/ggml-cuda/binbcast.cuh
@@ -7,3 +7,5 @@ void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse);
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 767ad83f60..85bc9e933b 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -107,9 +107,9 @@ constexpr bool ggml_cuda_has_arch(const int arch) {
return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__);
}
-constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur) {
+constexpr int ggml_cuda_highest_compiled_arch_impl(const int /*arch*/, const int cur) {
if (cur == 0) {
- GGML_ABORT("ggml was not compiled with any CUDA arch <= %d", arch);
+ return -1;
}
return cur;
}
@@ -420,16 +420,28 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
template
static __device__ __forceinline__ int warp_reduce_all(int x) {
-#ifdef GGML_USE_HIP
+ if (width == ggml_cuda_get_physical_warp_size()) {
+ return __all_sync(0xffffffff, x);
+ } else {
#pragma unroll
- for (int offset = width/2; offset > 0; offset >>= 1) {
- x = x && __shfl_xor_sync(0xffffffff, x, offset, width);
+ for (int offset = width/2; offset > 0; offset >>= 1) {
+ x = __shfl_xor_sync(0xffffffff, x, offset, width) && x;
+ }
+ return x;
+ }
+}
+
+template
+static __device__ __forceinline__ int warp_reduce_any(int x) {
+ if (width == ggml_cuda_get_physical_warp_size()) {
+ return __any_sync(0xffffffff, x);
+ } else {
+#pragma unroll
+ for (int offset = width/2; offset > 0; offset >>= 1) {
+ x = __shfl_xor_sync(0xffffffff, x, offset, width) || x;
+ }
+ return x;
}
- return x;
-#else
- static_assert(width == WARP_SIZE, "width != WARP_SIZE not implemented");
- return __all_sync(0xffffffff, x);
-#endif // GGML_USE_HIP
}
template
diff --git a/ggml/src/ggml-cuda/conv2d.cu b/ggml/src/ggml-cuda/conv2d.cu
new file mode 100644
index 0000000000..cf878d1fd1
--- /dev/null
+++ b/ggml/src/ggml-cuda/conv2d.cu
@@ -0,0 +1,171 @@
+#include "conv2d.cuh"
+
+struct conv_params { // convolution geometry, passed by value to conv2d_kernel; filled positionally in ggml_cuda_op_conv2d, so member order matters
+ const int64_t IW, IH; // input width / height
+ const int64_t OW, OH; // output width / height
+ const int64_t KW, KH; // kernel width / height
+ const int64_t ST_X, ST_Y; // stride along x / y
+ const int64_t PD_X, PD_Y; // padding along x / y
+ const int64_t DL_X, DL_Y; // dilation along x / y
+ const int64_t IC, OC; // input / output channel counts
+ const int64_t B; // number of batches
+ const int64_t TOTAL; // B * OC * OH * OW, the number of output elements
+};
+
+struct kernel_bounds { // kernel tap ranges produced by calculate_kernel_bounds
+ int64_t y_min, y_max; // conv2d_kernel iterates ky over [y_min, y_max)
+ int64_t x_min, x_max; // conv2d_kernel iterates kx over [x_min, x_max)
+};
+
+__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) {
+    return (a < b) ? b : a;  // the larger of the two operands
+}
+
+__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) {
+    return (b < a) ? b : a;  // the smaller of the two operands
+}
+
+__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) {
+    const int64_t num_y = P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1;  // shared numerator of both y bounds (rounds the division up)
+    const int64_t num_x = P.PD_X - out_x * P.ST_X + P.DL_X - 1;  // shared numerator of both x bounds
+    return {  // aggregate order: y_min, y_max, x_min, x_max
+        max64(0, num_y / P.DL_Y), min64(P.KH, (P.IH + num_y) / P.DL_Y),
+        max64(0, num_x / P.DL_X), min64(P.KW, (P.IW + num_x) / P.DL_X),
+    };
+}
+
+// Maps an output coordinate and a kernel tap to the input coordinate: out * stride + kern * dilation - padding.
+// Returns int64_t (previously int) so the 64-bit expression is not truncated before the callers store it
+// in their int64_t in_y / in_x variables.
+__device__ __forceinline__ int64_t calculate_input_coord(
+    int64_t out_coord, int64_t kern_coord, int64_t stride, int64_t dilation, int64_t padding) {
+    return out_coord * stride + kern_coord * dilation - padding;
+}
+
+// Index arithmetic for tensors stored with x fastest, then y, then channel, then batch.
+struct whcn_layout {
+    __device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
+        return ((n * P.IC + c) * P.IH + y) * P.IW + x;
+    }
+
+    __device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) {
+        return ((c_out * P.IC + c_in) * P.KH + ky) * P.KW + kx;
+    }
+
+    __device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
+        return ((n * P.OC + c) * P.OH + y) * P.OW + x;
+    }
+
+    // Inverse of output_index: splits a flat output index into batch, channel, row and column.
+    __device__ static void unpack_indices(int64_t global_idx, const conv_params & P,
+                                          int64_t & n, int64_t & c, int64_t & out_y, int64_t & out_x) {
+        const int64_t rows   = global_idx / P.OW;  // flat (n, c, y) index
+        const int64_t planes = rows / P.OH;        // flat (n, c) index
+        out_x = global_idx % P.OW;
+        out_y = rows % P.OH;
+        c     = planes % P.OC;
+        n     = planes / P.OC;
+    }
+};
+
+template
+static __global__ void conv2d_kernel(const float * __restrict__ input, // direct convolution: each thread computes one output element
+ const T * __restrict__ kernel,
+ float * __restrict__ output,
+ const conv_params P) {
+ const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x; // flat index of this thread's output element
+
+ if (global_idx >= P.TOTAL) { // threads past the last output element have nothing to do
+ return;
+ }
+
+ int64_t n, c_out, out_y, out_x;
+ Layout::unpack_indices(global_idx, P, n, c_out, out_y, out_x);
+
+ T acc = 0; // the sum is kept in T, the kernel element type
+
+ for (int64_t c_in = 0; c_in < P.IC; ++c_in) {
+ kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P); // independent of c_in; NOTE(review): could be hoisted out of this loop
+
+ for (int64_t ky = bounds.y_min; ky < bounds.y_max; ++ky) {
+ const int64_t in_y = calculate_input_coord(out_y, ky, P.ST_Y, P.DL_Y, P.PD_Y);
+
+ for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) {
+ const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X);
+
+ T input_val;
+ if (std::is_same::value) { // presumably tests whether T is half - TODO confirm
+ input_val = __float2half(input[Layout::input_index(n, c_in, in_y, in_x, P)]);
+ } else {
+ input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)];
+ }
+
+ T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
+ acc += (input_val * kernel_val);
+ }
+ }
+ }
+
+ // [N, OC, OH, OW]
+ output[Layout::output_index(n, c_out, out_y, out_x, P)] = (float) acc; // result is written back as float
+}
+
+template
+static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) { // launches conv2d_kernel for kernel element type T
+ const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE; // TOTAL / block size, rounded up; NOTE(review): the 64-bit result is narrowed to int
+ conv2d_kernel<<>>(X_D, K_D, Y_D, P);
+}
+
+static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) { // half-kernel entry point
+ conv2d_cuda(X_D, K_D, Y_D, P, st); // forwards all arguments unchanged
+}
+
+static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) { // float-kernel entry point
+ conv2d_cuda(X_D, K_D, Y_D, P, st); // forwards all arguments unchanged
+}
+
+// Host entry point for GGML_OP_CONV_2D: reads stride/padding/dilation from dst->op_params, gathers the
+// tensor extents and launches the direct convolution for the kernel tensor's type (F16 or F32).
+// dst->src[0] is the convolution kernel, dst->src[1] the input; input and output data are accessed as float.
+void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * kernel = dst->src[0];
+    const ggml_tensor * input  = dst->src[1];
+    const void *  K_D = kernel->data;  // reinterpreted as half or float below, depending on kernel->type
+    const float * X_D = (const float *) input->data;
+    float *       Y_D = (float *) dst->data;
+
+    GGML_ASSERT(ggml_is_contiguous(kernel));
+    GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32);
+    GGML_ASSERT(input->ne[2] == kernel->ne[2]);  // same number of input channels
+
+    cudaStream_t st = ctx.stream();
+
+    const int32_t * p = (const int32_t *) dst->op_params;
+    const int64_t ST_X = p[0];  // stride_x
+    const int64_t ST_Y = p[1];  // stride_y
+    const int64_t PD_X = p[2];  // padding_x
+    const int64_t PD_Y = p[3];  // padding_y
+    const int64_t DL_X = p[4];  // dilation_x
+    const int64_t DL_Y = p[5];  // dilation_y
+    GGML_ASSERT(p[6] == false);  // No cwhn
+
+    // Extents stay int64_t (the type of ne[] and of conv_params) so nothing is narrowed to 32 bits.
+    const int64_t IW = input->ne[0];   // input_w
+    const int64_t IH = input->ne[1];   // input_h
+    const int64_t OW = dst->ne[0];     // output_w
+    const int64_t OH = dst->ne[1];     // output_h
+    const int64_t KW = kernel->ne[0];  // kernel_w
+    const int64_t KH = kernel->ne[1];  // kernel_h
+    const int64_t IC = input->ne[2];   // input_channels
+    const int64_t OC = kernel->ne[3];  // output_channels
+    const int64_t B  = input->ne[3];   // n_batches
+
+    const int64_t total = B * OC * OH * OW;  // 64-bit product: cannot overflow the way the former int product could
+    const conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total };
+
+    if (kernel->type == GGML_TYPE_F16) {
+        conv2d_cuda_f16(X_D, (const half *) K_D, Y_D, params, st);
+    } else {
+        conv2d_cuda_f32(X_D, (const float *) K_D, Y_D, params, st);
+    }
+}
diff --git a/ggml/src/ggml-cuda/conv2d.cuh b/ggml/src/ggml-cuda/conv2d.cuh
new file mode 100644
index 0000000000..ce4802c7ed
--- /dev/null
+++ b/ggml/src/ggml-cuda/conv2d.cuh
@@ -0,0 +1,5 @@
+#pragma once
+#include "common.cuh"
+
+#define CUDA_CONV2D_BLOCK_SIZE 256
+void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu
index 6239d184d0..a900799a99 100644
--- a/ggml/src/ggml-cuda/fattn-tile-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu
@@ -258,7 +258,7 @@ static __global__ void flash_attn_tile_ext_f16(
const half val = hexp(sink - kqmax[j0/nwarps]);
kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
if (threadIdx.x == 0) {
- kqsum[j0/nwarps].x = __hadd(kqsum[j0/nwarps].x, val);
+ kqsum[j0/nwarps].x = __hadd(__low2half(kqsum[j0/nwarps]), val);
}
#pragma unroll
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 4e17fd211e..e06f95f081 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -12,6 +12,7 @@
#include "ggml-cuda/clamp.cuh"
#include "ggml-cuda/concat.cuh"
#include "ggml-cuda/conv-transpose-1d.cuh"
+#include "ggml-cuda/conv2d.cuh"
#include "ggml-cuda/conv2d-dw.cuh"
#include "ggml-cuda/conv2d-transpose.cuh"
#include "ggml-cuda/convert.cuh"
@@ -49,6 +50,7 @@
#include "ggml-cuda/wkv.cuh"
#include "ggml-cuda/gla.cuh"
#include "ggml-cuda/set-rows.cuh"
+#include "ggml-cuda/pad_reflect_1d.cuh"
#include "ggml.h"
#include
@@ -203,6 +205,8 @@ static ggml_cuda_device_info ggml_cuda_init() {
GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__);
#endif // GGML_CUDA_FORCE_CUBLAS
GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
+
+ std::vector> turing_devices_without_mma;
for (int id = 0; id < info.device_count; ++id) {
int device_vmm = 0;
@@ -260,7 +264,25 @@ static ggml_cuda_device_info ggml_cuda_init() {
info.devices[id].cc = 100*prop.major + 10*prop.minor;
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
-#endif // defined(GGML_USE_HIP)
+ std::string device_name(prop.name);
+ if (device_name == "NVIDIA GeForce MX450") {
+ turing_devices_without_mma.push_back({ id, device_name });
+ } else if (device_name == "NVIDIA GeForce MX550") {
+ turing_devices_without_mma.push_back({ id, device_name });
+ } else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") {
+ turing_devices_without_mma.push_back({ id, device_name });
+ }
+#endif // defined(GGML_USE_HIP)
+ }
+
+ if (ggml_cuda_highest_compiled_arch(GGML_CUDA_CC_TURING) >= GGML_CUDA_CC_TURING && !turing_devices_without_mma.empty()) {
+ GGML_LOG_INFO("The following devices will have suboptimal performance due to a lack of tensor cores:\n");
+ for (size_t device_pos = 0; device_pos < turing_devices_without_mma.size(); device_pos++) {
+ GGML_LOG_INFO(
+ " Device %d: %s\n", turing_devices_without_mma[device_pos].first, turing_devices_without_mma[device_pos].second.c_str());
+ }
+ GGML_LOG_INFO(
+ "Consider compiling with CMAKE_CUDA_ARCHITECTURES=61-virtual;80-virtual and DGGML_CUDA_FORCE_MMQ to force the use of the Pascal code for Turing.\n");
}
for (int id = 0; id < info.device_count; ++id) {
@@ -2352,6 +2374,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_PAD:
ggml_cuda_op_pad(ctx, dst);
break;
+ case GGML_OP_PAD_REFLECT_1D:
+ ggml_cuda_op_pad_reflect_1d(ctx, dst);
+ break;
case GGML_OP_ARANGE:
ggml_cuda_op_arange(ctx, dst);
break;
@@ -2427,6 +2452,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_IM2COL:
ggml_cuda_op_im2col(ctx, dst);
break;
+ case GGML_OP_CONV_2D:
+ ggml_cuda_op_conv2d(ctx, dst);
+ break;
case GGML_OP_CONV_2D_DW:
ggml_cuda_op_conv2d_dw(ctx, dst);
break;
@@ -2793,9 +2821,14 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
return false;
}
- if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+ if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
const ggml_tensor *mul = cgraph->nodes[node_idx+1];
+ const ggml_tensor *add = nullptr;
+
+ if (ops.size() == 3 && ops.begin()[2] == GGML_OP_ADD) {
+ add = cgraph->nodes[node_idx+2];
+ }
GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
@@ -2807,6 +2840,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
return false;
}
+ if (add && (add->src[0]->type != GGML_TYPE_F32 ||
+ add->src[1]->type != GGML_TYPE_F32 ||
+ add->type != GGML_TYPE_F32) ) {
+ return false;
+ }
+
//if rms norm is the B operand, then we don't handle broadcast
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
return false;
@@ -2817,6 +2856,10 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
return false;
}
+ if (add && (!ggml_is_contiguous(add->src[0]) || !ggml_is_contiguous_rows(add->src[1]))) {
+ return false;
+ }
+
return true;
}
@@ -2863,7 +2906,46 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
if (!disable_fusion) {
- if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
+
+ if (node->op == GGML_OP_ADD) {
+ int n_fuse = 0;
+ ggml_op ops[8];
+ std::fill(ops, ops + 8, GGML_OP_ADD);
+
+ for (; n_fuse <= 6; ++n_fuse){
+ if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {
+ break;
+ }
+ if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) {
+ break;
+ }
+ if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) {
+ break;
+ }
+ }
+
+ n_fuse++;
+
+ if (n_fuse > 1) {
+ for (int j = 0; j < n_fuse - 1; ++j) {
+ node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
+ }
+ cgraph->nodes[i + n_fuse - 1]->data = node->data;
+ ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse);
+ i += n_fuse - 1;
+
+ continue;
+ }
+ }
+
+
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
+ ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
+ i += 2;
+ continue;
+ }
+
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) {
ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
i++;
continue;
@@ -3082,7 +3164,7 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
return false;
}
-#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
+#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) || defined(GGML_USE_HIP)
cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
if (err != cudaSuccess) {
// clear the error
@@ -3477,19 +3559,21 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
}
case GGML_OP_IM2COL:
+ case GGML_OP_CONV_2D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_CONV_TRANSPOSE_2D:
case GGML_OP_POOL_2D:
case GGML_OP_SUM:
- case GGML_OP_SUM_ROWS:
- case GGML_OP_MEAN:
case GGML_OP_ARGSORT:
case GGML_OP_ACC:
return true;
+ case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
case GGML_OP_GROUP_NORM:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
+ case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_LEAKY_RELU:
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index 576032a0ce..714b23f9f4 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -3,6 +3,140 @@
#include
+// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
+struct mmq_ids_helper_store {
+    uint32_t data;  // bits 0..21 hold "it", bits 22..31 hold "iex_used"
+
+    __device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used)
+        : data((iex_used << 22) | (it & 0x003FFFFF)) {
+    }
+
+    __device__ uint32_t it() const {
+        return data & ((1u << 22) - 1);
+    }
+
+    __device__ uint32_t iex_used() const {
+        return data >> 22;
+    }
+};
+static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store");
+
+// Helper function for mul_mat_id, converts ids to a more convenient format.
+// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
+// ids_dst describes the same mapping but for the dst tensor.
+// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
+template
+__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
+static __global__ void mmq_ids_helper( // each block handles the expert given by blockIdx.x
+ const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
+ const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+ const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template; // template value 0 selects the runtime count
+ const int expert = blockIdx.x;
+
+ extern __shared__ char data_mmq_ids_helper[]; // dynamic shared memory; size is chosen by the launcher
+ mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper;
+
+ int nex_prev = 0; // Number of columns for experts with a lower index.
+ int it_compact = 0; // Running index for the compact slice of this expert.
+
+ if constexpr (n_expert_used_template == 0) {
+ // Generic implementation:
+ for (int it = 0; it < n_tokens; ++it) {
+ int iex_used = -1; // The index at which the expert is used, if any.
+ for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
+ const int expert_used = ids[it*si1 + iex];
+ nex_prev += expert_used < expert;
+ if (expert_used == expert) {
+ iex_used = iex;
+ }
+ }
+
+ if (iex_used != -1) {
+ store[it_compact] = mmq_ids_helper_store(it, iex_used);
+ }
+
+ if (warp_reduce_any(iex_used != -1)) { // advance on every thread when any thread found the expert for this token
+ it_compact++;
+ }
+ }
+ } else {
+ // Implementation optimized for specific numbers of experts used:
+ static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
+ const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
+ for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) { // warp_size/neu_padded tokens are handled per iteration
+ const int it = it0 + threadIdx.x / neu_padded;
+
+ const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
+ const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
+ ids[it*si1 + iex] : INT_MAX; // INT_MAX for padding lanes: never equal to or lower than a real expert index
+ const int iex_used = expert_used == expert ? iex : -1;
+ nex_prev += expert_used < expert;
+
+ // Whether the threads at this token position have used the expert:
+ const int it_compact_add_self = warp_reduce_any(iex_used != -1);
+
+ // Do a scan over threads at lower token positions in warp to get the correct index for writing data:
+ int it_compact_add_lower = 0;
+#pragma unroll
+ for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
+ const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
+ if (threadIdx.x >= offset) {
+ it_compact_add_lower += tmp;
+ }
+ }
+
+ if (iex_used != -1) {
+ store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used);
+ }
+
+ // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
+ it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
+ }
+ }
+ nex_prev = warp_reduce_sum(nex_prev); // combine the per-thread partial counts
+
+ for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) { // threads share the compact entries, warp_size apart
+ const mmq_ids_helper_store store_it = store[itc];
+ const int it = store_it.it();
+ const int iex_used = store_it.iex_used();
+ ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
+ ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
+ }
+
+ if (threadIdx.x != 0) { // only thread 0 of the block writes the bounds below
+ return;
+ }
+
+ expert_bounds[expert] = nex_prev;
+
+ if (expert < gridDim.x - 1) { // only the block of the last expert writes the closing bound
+ return;
+ }
+
+ expert_bounds[gridDim.x] = nex_prev + it_compact;
+}
+
+template
+static void launch_mmq_ids_helper( // host-side launcher for the mmq_ids_helper kernel
+ const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
+ const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
+ GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mmq_ids_helper_store"); // token index must fit the 22-bit field
+ GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store"); // expert slot index must fit the 10-bit field
+
+ const int id = ggml_cuda_get_device();
+ const int warp_size = ggml_cuda_info().devices[id].warp_size;
+ const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+ CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper, smpbo);
+
+ const dim3 num_blocks(n_experts, 1, 1); // one block per expert
+ const dim3 block_size(warp_size, 1, 1); // warp_size threads per block
+ const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store); // room for one store entry per token
+ GGML_ASSERT(nbytes_shared <= smpbo);
+ mmq_ids_helper<<>>
+ (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
+}
+
static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
switch (args.type_x) {
case GGML_TYPE_Q4_0:
@@ -137,7 +271,7 @@ void ggml_cuda_mul_mat_q(
ne00, ne01, ne1, s01, ne11, s1,
ne02, ne12, s02, s12, s2,
ne03, ne13, s03, s13, s3,
- use_stream_k};
+ use_stream_k, ne1};
ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
return;
}
@@ -148,54 +282,50 @@ void ggml_cuda_mul_mat_q(
const int64_t n_expert_used = ids->ne[0];
const int64_t ne_get_rows = ne12 * n_expert_used;
+ GGML_ASSERT(ne1 == n_expert_used);
- std::vector ids_host(ggml_nbytes(ids));
- std::vector ids_src1_host;
- ids_src1_host.reserve(ne_get_rows);
- std::vector ids_dst_host;
- ids_dst_host.reserve(ne_get_rows);
- std::vector tokens_per_expert_host(ne02);
- std::vector expert_bounds_host(ne02 + 1);
- ggml_cuda_pool_alloc ids_buf_dev(ctx.pool());
+ ggml_cuda_pool_alloc ids_src1(ctx.pool(), ne_get_rows);
+ ggml_cuda_pool_alloc ids_dst(ctx.pool(), ne_get_rows);
+ ggml_cuda_pool_alloc expert_bounds(ctx.pool(), ne02 + 1);
- CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
- CUDA_CHECK(cudaStreamSynchronize(stream));
+ {
+ GGML_ASSERT(ids->nb[0] == ggml_element_size(ids));
+ const int si1 = ids->nb[1] / ggml_element_size(ids);
+ const int sis1 = nb12 / nb11;
- for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices
- for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens
- for (int64_t iex = 0; iex < n_expert_used; ++iex) {
- const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]);
- assert(expert_to_use >= 0 && expert_to_use < ne02);
- if (expert_to_use == i02) {
- ids_src1_host.push_back(i12*(nb12/nb11) + iex % ne11);
- ids_dst_host.push_back(i12*ne1 + iex);
- tokens_per_expert_host[i02]++;
- break;
- }
- }
+ switch (n_expert_used) {
+ case 2:
+ launch_mmq_ids_helper< 2> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+ ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+ break;
+ case 4:
+ launch_mmq_ids_helper< 4> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+ ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+ break;
+ case 6:
+ launch_mmq_ids_helper< 6> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+ ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+ break;
+ case 8:
+ launch_mmq_ids_helper< 8> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+ ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+ break;
+ case 16:
+ launch_mmq_ids_helper<16> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+ ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+ break;
+ case 32:
+ launch_mmq_ids_helper<32> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+ ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+ break;
+ default:
+ launch_mmq_ids_helper< 0> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+ ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+ break;
}
+ CUDA_CHECK(cudaGetLastError());
}
- int32_t cumsum = 0;
- for (int64_t i = 0; i < ne02; ++i) {
- expert_bounds_host[i] = cumsum;
- cumsum += tokens_per_expert_host[i];
- }
- expert_bounds_host[ne02] = cumsum;
-
- std::vector ids_buf_host;
- ids_buf_host.reserve(ids_src1_host.size() + ids_dst_host.size() + expert_bounds_host.size());
- ids_buf_host.insert(ids_buf_host.end(), ids_src1_host.begin(), ids_src1_host.end());
- ids_buf_host.insert(ids_buf_host.end(), ids_dst_host.begin(), ids_dst_host.end());
- ids_buf_host.insert(ids_buf_host.end(), expert_bounds_host.begin(), expert_bounds_host.end());
- ids_buf_dev.alloc(ids_buf_host.size() + get_mmq_x_max_host(cc)); // Expert bounds are padded on device.
- CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_buf_host.data(), ids_buf_host.size()*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
- CUDA_CHECK(cudaStreamSynchronize(stream));
-
- const int32_t * ids_src1_dev = ids_buf_dev.ptr;
- const int32_t * ids_dst_dev = ids_src1_dev + ids_src1_host.size();
- const int32_t * expert_bounds_dev = ids_dst_dev + ids_dst_host.size();
-
const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +
get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
ggml_cuda_pool_alloc src1_q8_1(ctx.pool(), nbytes_src1_q8_1);
@@ -208,7 +338,7 @@ void ggml_cuda_mul_mat_q(
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[2] / ts_src1;
- quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type,
+ quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type,
ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
CUDA_CHECK(cudaGetLastError());
}
@@ -218,11 +348,11 @@ void ggml_cuda_mul_mat_q(
// Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
const mmq_args args = {
- src0_d, src0->type, (const int *) src1_q8_1.ptr, ids_dst_dev, expert_bounds_dev, dst_d,
+ src0_d, src0->type, (const int *) src1_q8_1.get(), ids_dst.get(), expert_bounds.get(), dst_d,
ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
ne02, ne02, s02, s12, s2,
ne03, ne13, s03, s13, s3,
- use_stream_k};
+ use_stream_k, ne12};
ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
}
@@ -262,7 +392,7 @@ void ggml_cuda_op_mul_mat_q(
ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,
1, 1, 0, 0, 0,
1, 1, 0, 0, 0,
- use_stream_k};
+ use_stream_k, src1_ncols};
ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index 650f708067..c9a07e82fe 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -3138,7 +3138,8 @@ static __global__ void mul_mat_q(
const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
- const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+ const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+ const int ncols_max) {
// Skip unused template specializations for faster compilation:
if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
@@ -3152,7 +3153,7 @@ static __global__ void mul_mat_q(
constexpr int qk = ggml_cuda_type_traits::qk;
constexpr int mmq_y = get_mmq_y_device();
- const int ntx = (ncols_dst + mmq_x - 1) / mmq_x; // Number of tiles x
+ const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x
const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
// Initialize the ids for writing back data with just the index.
@@ -3376,7 +3377,8 @@ template
static __global__ void mul_mat_q_stream_k_fixup(
const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
- const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst) {
+ const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst,
+ const int ncols_max) {
constexpr int mmq_y = get_mmq_y_device();
constexpr int qk = ggml_cuda_type_traits::qk;
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
@@ -3387,7 +3389,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
- const int ntx = (ncols_dst + mmq_x - 1) / mmq_x;
+ const int ntx = (ncols_max + mmq_x - 1) / mmq_x;
const int nty = (nrows_x + mmq_y - 1) / mmq_y;
const int bidx0 = blockIdx.x;
@@ -3528,7 +3530,7 @@ struct mmq_args {
int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst;
int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
- bool use_stream_k;
+ bool use_stream_k; int64_t ncols_max;
};
template
@@ -3558,7 +3560,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared);
const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
- const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x;
+ const int ntx = (args.ncols_max + mmq_x - 1) / mmq_x;
const int ntzw = args.nchannels_y * args.nsamples_y;
const dim3 block_nums_xy_tiling(nty, ntx, ntzw);
@@ -3574,14 +3576,16 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
- sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+ sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+ args.ncols_max);
} else {
constexpr bool need_check = true;
mul_mat_q<<>>
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
- sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+ sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+ args.ncols_max);
}
return;
}
@@ -3601,7 +3605,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
- sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+ sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+ args.ncols_max);
if (!fixup_needed) {
return;
@@ -3609,14 +3614,16 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
mul_mat_q_stream_k_fixup<<>>
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
- args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
+ args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
+ args.ncols_max);
} else {
constexpr bool need_check = true;
mul_mat_q<<>>
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
- sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+ sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+ args.ncols_max);
if (!fixup_needed) {
return;
@@ -3624,7 +3631,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
mul_mat_q_stream_k_fixup<<>>
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
- args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
+ args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
+ args.ncols_max);
}
}
@@ -3649,7 +3657,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
continue;
}
- const int ntiles_x = (args.ncols_y + mmq_x - 1) / mmq_x;
+ const int ntiles_x = (args.ncols_max + mmq_x - 1) / mmq_x;
if (ntiles_x < ntiles_x_best) {
mmq_x_best = mmq_x;
diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu
index bddcca51b7..d5157d958b 100644
--- a/ggml/src/ggml-cuda/norm.cu
+++ b/ggml/src/ggml-cuda/norm.cu
@@ -104,12 +104,30 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
}
}
-template
-static __global__ void rms_norm_f32(
- const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
- const int64_t stride_sample, const float eps, const float * mul = nullptr, const int64_t mul_stride_row = 0,
- const int64_t mul_stride_channel = 0, const int64_t mul_stride_sample = 0, const int mul_ncols = 0,
- const int mul_nrows = 0, const int mul_nchannels = 0, const int mul_nsamples = 0) {
+template
+static __global__ void rms_norm_f32(const float * x, float * dst,
+ const int ncols,
+ const int64_t stride_row,
+ const int64_t stride_channel,
+ const int64_t stride_sample,
+ const float eps,
+ const float * mul = nullptr,
+ const int64_t mul_stride_row = 0,
+ const int64_t mul_stride_channel = 0,
+ const int64_t mul_stride_sample = 0,
+ const int mul_ncols = 0,
+ const int mul_nrows = 0,
+ const int mul_nchannels = 0,
+ const int mul_nsamples = 0,
+ const float * add = nullptr,
+ const int64_t add_stride_row = 0,
+ const int64_t add_stride_channel = 0,
+ const int64_t add_stride_sample = 0,
+ const int add_ncols = 0,
+ const int add_nrows = 0,
+ const int add_nchannels = 0,
+ const int add_nsamples = 0) {
+
const int nrows = gridDim.x;
const int nchannels = gridDim.y;
@@ -118,6 +136,8 @@ static __global__ void rms_norm_f32(
const int sample = blockIdx.z;
const int tid = threadIdx.x;
+ static_assert(!do_add || do_multiply, "fusing add is not supported without multiplying");
+
x += sample*stride_sample + channel*stride_channel + row*stride_row;
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
@@ -128,6 +148,13 @@ static __global__ void rms_norm_f32(
mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row;
}
+ if constexpr (do_add) {
+ const int add_row = row % add_nrows;
+ const int add_channel = channel % add_nchannels;
+ const int add_sample = sample % add_nsamples;
+ add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
+ }
+
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
@@ -154,7 +181,11 @@ static __global__ void rms_norm_f32(
const float scale = rsqrtf(mean + eps);
for (int col = tid; col < ncols; col += block_size) {
- if constexpr (do_multiply) {
+ if constexpr (do_multiply && do_add) {
+ const int mul_col = col % mul_ncols;
+ const int add_col = col % add_ncols;
+ dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
+ } else if constexpr (do_multiply) {
const int mul_col = col % mul_ncols;
dst[col] = scale * x[col] * mul[mul_col];
} else {
@@ -331,23 +362,70 @@ static void rms_norm_f32_cuda(
}
}
-static void rms_norm_mul_f32_cuda(
- const float * x, const float * mul, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
- const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
- const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample,
- const int mul_ncols, const int mul_nrows, const int mul_nchannels, const int mul_nsamples,
- const float eps, cudaStream_t stream) {
+static void rms_norm_mul_f32_cuda(const float * x,
+ const float * mul,
+ const float * add,
+ float * dst,
+ const int ncols,
+ const int nrows,
+ const int nchannels,
+ const int nsamples,
+ const int64_t stride_row,
+ const int64_t stride_channel,
+ const int64_t stride_sample,
+ const int64_t mul_stride_row,
+ const int64_t mul_stride_channel,
+ const int64_t mul_stride_sample,
+ const int mul_ncols,
+ const int mul_nrows,
+ const int mul_nchannels,
+ const int mul_nsamples,
+ const int64_t add_stride_row,
+ const int64_t add_stride_channel,
+ const int64_t add_stride_sample,
+ const int add_ncols,
+ const int add_nrows,
+ const int add_nchannels,
+ const int add_nsamples,
+ const float eps,
+ cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (mul == nullptr) {
rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
return;
}
- if (ncols < 1024) {
- const dim3 block_dims(WARP_SIZE, 1, 1);
- rms_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
+ if (add == nullptr) {
+ if (ncols < 1024) {
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ rms_norm_f32<<>>(x, dst,
+ ncols, stride_row, stride_channel, stride_sample, eps,
+ mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
+ } else {
+ const dim3 block_dims(1024, 1, 1);
+ rms_norm_f32<1024, true><<>>(x, dst,
+ ncols, stride_row, stride_channel, stride_sample, eps,
+ mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
+ }
} else {
- const dim3 block_dims(1024, 1, 1);
- rms_norm_f32<1024, true><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
+ if (ncols < 1024) {
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ rms_norm_f32<<>>(x, dst,
+ ncols, stride_row, stride_channel, stride_sample, eps,
+ mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
+ add, add_stride_row, add_stride_channel, add_stride_sample,
+ add_ncols, add_nrows, add_nchannels, add_nsamples);
+ } else {
+ const dim3 block_dims(1024, 1, 1);
+ rms_norm_f32<1024, true, true><<>>(x, dst,
+ ncols, stride_row, stride_channel, stride_sample, eps,
+ mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
+ add, add_stride_row, add_stride_channel, add_stride_sample,
+ add_ncols, add_nrows, add_nchannels, add_nsamples);
+ }
}
}
@@ -491,7 +569,102 @@ void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor *
const int mul_nchannels = mul_src->ne[2];
const int mul_nsamples = mul_src->ne[3];
- rms_norm_mul_f32_cuda(src0_d, mul_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, mul_s01, mul_s02, mul_s03, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, eps, stream);
+ rms_norm_mul_f32_cuda(src0_d, mul_d, nullptr, dst_d,
+ ne00, ne01, ne02, ne03,
+ /*s00*/ s01, s02, s03,
+ /*mul_s00*/ mul_s01, mul_s02, mul_s03,
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
+ /*add_s00*/ 0, 0, 0,
+ 0, 0, 0, 0,
+ eps, stream);
+}
+
+void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx, // runs norm(dst->src[0]) * mul + add in one launch, writing into add_tensor->data
+ ggml_tensor * dst, // node whose src[0] is the norm input and whose op_params hold eps
+ ggml_tensor * mul_tensor, // node that has dst as one of its two operands
+ ggml_tensor * add_tensor) { // node that has mul_tensor as one of its two operands; its data buffer receives the result
+ const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
+ float eps = 0.0f;
+
+ memcpy(&eps, dst->op_params, sizeof(float)); // eps is read as the first float stored in dst->op_params
+
+ const float * src0_d = (const float *) rms_norm_src->data;
+ const float * mul_d = nullptr;
+ const ggml_tensor * mul_src = nullptr;
+
+ if (mul_tensor->src[0] == dst) { // the multiplier is whichever operand of mul_tensor is not dst
+ mul_d = (float *) mul_tensor->src[1]->data;
+ mul_src = mul_tensor->src[1];
+ } else if (mul_tensor->src[1] == dst) {
+ mul_d = (float *) mul_tensor->src[0]->data;
+ mul_src = mul_tensor->src[0];
+ } else {
+ GGML_ASSERT(false); // dst is not an operand of mul_tensor: the three nodes do not form the expected chain
+ }
+
+ const float * add_d = nullptr;
+ const ggml_tensor * add_src = nullptr;
+
+ if (add_tensor->src[0] == mul_tensor) { // the addend is whichever operand of add_tensor is not mul_tensor
+ add_d = (float *) add_tensor->src[1]->data;
+ add_src = add_tensor->src[1];
+ } else if (add_tensor->src[1] == mul_tensor) {
+ add_d = (float *) add_tensor->src[0]->data;
+ add_src = add_tensor->src[0];
+ } else {
+ GGML_ASSERT(false); // mul_tensor is not an operand of add_tensor
+ }
+
+ float * dst_d = (float *) add_tensor->data; // output goes to the last node of the chain, not to dst
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
+ GGML_ASSERT(add_tensor->type == GGML_TYPE_F32); // NOTE(review): mul_src/add_src types are not asserted here although their data is read as float - presumably validated by the caller
+ GGML_ASSERT(eps >= 0.0f);
+
+ const int64_t ne00 = rms_norm_src->ne[0];
+ const int64_t ne01 = rms_norm_src->ne[1];
+ const int64_t ne02 = rms_norm_src->ne[2];
+ const int64_t ne03 = rms_norm_src->ne[3];
+
+ const size_t ts0 = ggml_type_size(rms_norm_src->type);
+ GGML_ASSERT(rms_norm_src->nb[0] == ts0); // elements must be densely packed along dim 0
+ const int64_t s01 = rms_norm_src->nb[1] / ts0; // byte strides converted to element strides
+ const int64_t s02 = rms_norm_src->nb[2] / ts0;
+ const int64_t s03 = rms_norm_src->nb[3] / ts0;
+
+ const size_t ts_mul = ggml_type_size(mul_src->type);
+ GGML_ASSERT(mul_src->nb[0] == ts_mul); // same dense-dim-0 requirement for the multiplier
+ const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
+ const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
+ const int64_t mul_s03 = mul_src->nb[3] / ts_mul;
+
+ const int mul_ncols = mul_src->ne[0]; // multiplier extents are passed separately from the input extents
+ const int mul_nrows = mul_src->ne[1];
+ const int mul_nchannels = mul_src->ne[2];
+ const int mul_nsamples = mul_src->ne[3];
+
+ const size_t ts_add = ggml_type_size(add_src->type);
+ GGML_ASSERT(add_src->nb[0] == ts_add); // same dense-dim-0 requirement for the addend
+ const int64_t add_s01 = add_src->nb[1] / ts_add;
+ const int64_t add_s02 = add_src->nb[2] / ts_add;
+ const int64_t add_s03 = add_src->nb[3] / ts_add;
+
+ const int add_ncols = add_src->ne[0]; // addend extents, likewise passed separately
+ const int add_nrows = add_src->ne[1];
+ const int add_nchannels = add_src->ne[2];
+ const int add_nsamples = add_src->ne[3];
+
+ rms_norm_mul_f32_cuda(src0_d, mul_d,add_d,dst_d,
+ ne00,ne01, ne02, ne03,
+ /*s00*/ s01, s02, s03,
+ /*mul_s00*/ mul_s01, mul_s02, mul_s03,
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
+ /*add_s00*/ add_s01, add_s02, add_s03,
+ add_ncols, add_nrows, add_nchannels, add_nsamples,
+ eps, stream);
 }
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
diff --git a/ggml/src/ggml-cuda/norm.cuh b/ggml/src/ggml-cuda/norm.cuh
index 7ea7bd4df3..a74f637672 100644
--- a/ggml/src/ggml-cuda/norm.cuh
+++ b/ggml/src/ggml-cuda/norm.cuh
@@ -8,6 +8,11 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor);
+void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
+ ggml_tensor * dst,
+ ggml_tensor * mul_tensor,
+ ggml_tensor * add_tensor);
+
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/pad_reflect_1d.cu b/ggml/src/ggml-cuda/pad_reflect_1d.cu
new file mode 100644
index 0000000000..4ed34aec3d
--- /dev/null
+++ b/ggml/src/ggml-cuda/pad_reflect_1d.cu
@@ -0,0 +1,82 @@
+#include "pad_reflect_1d.cuh"
+
+static __global__ void pad_reflect_1d_kernel_f32( // f32 reflect padding along dim 0; each block handles the row selected by blockIdx (x, y, z)
+ const void * __restrict__ src0, // input buffer, addressed through the nb0x byte strides
+ void * __restrict__ dst, // output buffer, addressed through the nbx byte strides
+ const int64_t ne0, // number of output elements along dim 0
+ const int64_t ne00, // not read in this kernel body
+ const int64_t ne01, // upper bound checked against blockIdx.x
+ const int64_t ne02, // upper bound checked against blockIdx.y
+ const int64_t ne03, // upper bound checked against blockIdx.z
+ const int64_t nb00, // src0 byte stride between consecutive dim-0 elements
+ const int64_t nb01, // src0 byte stride for index i1
+ const int64_t nb02, // src0 byte stride for index i2
+ const int64_t nb03, // src0 byte stride for index i3
+ const int64_t nb0, // dst byte stride between consecutive dim-0 elements
+ const int64_t nb1, // dst byte stride for index i1
+ const int64_t nb2, // dst byte stride for index i2
+ const int64_t nb3, // dst byte stride for index i3
+ const int p0, // number of reflected elements written before the copied span
+ const int p1) { // number of reflected elements written after the copied span
+
+ const int64_t i3 = blockIdx.z;
+ const int64_t i2 = blockIdx.y;
+ const int64_t i1 = blockIdx.x;
+
+ if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) { // skip blocks whose indices fall outside the given extents
+ return;
+ }
+
+ const char * src0_ptr = (const char *)src0 + i3*nb03 + i2*nb02 + i1*nb01; // byte pointer to the start of this block's input row
+ char * dst_ptr = (char *)dst + i3*nb3 + i2*nb2 + i1*nb1; // byte pointer to the start of this block's output row
+
+ for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) { // threads of the block step through dim 0 in increments of blockDim.x
+ float value;
+
+ if (i0 < p0) {
+ // Left padding - reflect: reads src index p0 - i0 (range 1..p0), so src[0] itself is not repeated
+ value = *(const float *)(src0_ptr + (p0 - i0) * nb00);
+ } else if (i0 < ne0 - p1) {
+ // Middle - copy: output index i0 takes src index i0 - p0
+ value = *(const float *)(src0_ptr + (i0 - p0) * nb00);
+ } else {
+ // Right padding - reflect: with k = i0 - (ne0 - p1), src_idx equals (ne0 - p0 - p1) - 2 - k
+ int64_t src_idx = (ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1;
+ value = *(const float *)(src0_ptr + src_idx * nb00);
+ }
+
+ *(float *)(dst_ptr + i0 * nb0) = value; // store through the dst byte stride for dim 0
+ }
+}
+
+void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { // host entry point: validates types/shape and launches pad_reflect_1d_kernel_f32 on ctx's stream
+ const ggml_tensor * src0 = dst->src[0];
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32); // the kernel reads and writes float, so both tensors must be F32
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ const int32_t * opts = (const int32_t *) dst->op_params; // op_params is read as an int32 array
+ const int p0 = opts[0]; // first entry: p0 passed to the kernel
+ const int p1 = opts[1]; // second entry: p1 passed to the kernel
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne01 = src0->ne[1];
+ const int64_t ne02 = src0->ne[2];
+ const int64_t ne03 = src0->ne[3];
+
+ const int64_t ne0 = dst->ne[0];
+
+ GGML_ASSERT(ne0 == ne00 + p0 + p1); // dst dim 0 must be src dim 0 plus both paddings
+
+ const dim3 block_dims(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1, 1);
+ const dim3 grid_dims(ne01, ne02, ne03); // NOTE(review): CUDA caps gridDim.y/z at 65535 - confirm ne02/ne03 stay within that
+
+ pad_reflect_1d_kernel_f32<<>>(
+ src0->data, dst->data,
+ ne0, ne00, ne01, ne02, ne03,
+ src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+ dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+ p0, p1
+ );
+}
diff --git a/ggml/src/ggml-cuda/pad_reflect_1d.cuh b/ggml/src/ggml-cuda/pad_reflect_1d.cuh
new file mode 100644
index 0000000000..15f2ed1737
--- /dev/null
+++ b/ggml/src/ggml-cuda/pad_reflect_1d.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_PAD_REFLECT_1D_BLOCK_SIZE 256
+
+void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu
index dc9a7d58d0..6b424381df 100644
--- a/ggml/src/ggml-cuda/ssm-scan.cu
+++ b/ggml/src/ggml-cuda/ssm-scan.cu
@@ -129,7 +129,7 @@ __global__ void __launch_bounds__(d_state, 1)
const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
const int seq_idx = blockIdx.y;
- const int group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float);
+ const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);
const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));
diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh
index d60292b83b..6baab1176f 100644
--- a/ggml/src/ggml-cuda/vecdotq.cuh
+++ b/ggml/src/ggml-cuda/vecdotq.cuh
@@ -28,7 +28,58 @@ static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32
return ((const int *) x)[i32]; // assume at least 4 byte alignment
}
+// q4 contains 8 indices with 4 bit each.
+// This function selects those bytes from table that are at those indices and returns them as int2.
+// The first int contains the bytes with even indices in q4, the second int contains the bytes with odd indices in q4.
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {
+#if defined(GGML_USE_HIP)
+ // Load the 16-byte table into four 32-bit unsigned integers.
+ const uint32_t *values = (const uint32_t *)table;
+
+ const uint32_t q_even = q4;
+ const uint32_t q_odd = (q4 >> 4);
+
+ // Perform lookups in the lower half of the table (indices 0-7).
+ uint32_t v_even_low = __builtin_amdgcn_perm(values[1], values[0], q_even & 0x07070707);
+ uint32_t v_odd_low = __builtin_amdgcn_perm(values[1], values[0], q_odd & 0x07070707);
+
+ // Perform lookups in the upper half of the table (indices 8-15).
+ uint32_t v_even_high = __builtin_amdgcn_perm(values[3], values[2], q_even & 0x07070707);
+ uint32_t v_odd_high = __builtin_amdgcn_perm(values[3], values[2], q_odd & 0x07070707);
+
+ // Select between the low and high results based on the MSB of each index nibble.
+ uint32_t mask_even = 0x03020100 | ((q_even & 0x08080808) >> 1);
+ uint32_t res_x = __builtin_amdgcn_perm(v_even_high, v_even_low, mask_even);
+ uint32_t mask_odd = 0x03020100 | ((q_odd & 0x08080808) >> 1);
+ uint32_t res_y = __builtin_amdgcn_perm(v_odd_high, v_odd_low, mask_odd);
+
+ return make_int2(res_x, res_y);
+#elif !defined(GGML_USE_MUSA)
+ // CUDA does not have an instruction for selecting bytes with 4 bit indices.
+ // However, __byte_perm is an instruction that selects bytes with 3 bit indices that can be used instead.
+ const uint32_t * table32 = (const uint32_t *) table;
+
+ // __byte_perm selects bytes based on the lower 16 bits in its third argument.
+ // Therefore, do 2 iterations over the 32 bits in q4 with 0 and 16 shift.
+ // To handle the fourth bit, first call __byte_perm both for the low and the high 64 bit of table, using the low 3 bits.
+ // Then, call __byte_perm again to select from the low and high bytes based on the fourth bit.
+ uint32_t tmp[2];
+ const uint32_t low_high_selection_indices = (0x32103210 | ((q4 & 0x88888888) >> 1));
+#pragma unroll
+ for (uint32_t i = 0; i < 2; ++i) {
+ const uint32_t shift = 16 * i;
+
+ const uint32_t low = __byte_perm(table32[0], table32[1], q4 >> shift);
+ const uint32_t high = __byte_perm(table32[2], table32[3], q4 >> shift);
+ tmp[i] = __byte_perm(low, high, low_high_selection_indices >> shift);
+ }
+
+ // tmp contains the bytes from table in the same order as the 4 bit indices in q4.
+ // However, for the result we need ints with all even/odd 4 bit indices in q4.
+ // Therefore, 2 more calls to __byte_perm to put the bytes in the correct order.
+ return make_int2(__byte_perm(tmp[0], tmp[1], 0x6420), __byte_perm(tmp[0], tmp[1], 0x7531));
+#else
+ // Generic implementation.
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
const int8_t * q0_8 = (const int8_t *) &q0_32;
const char4 val0_8 = make_char4(
@@ -40,6 +91,7 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]);
return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
+#endif
}
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
index 6e9c67aca0..c6a33d5de3 100644
--- a/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ggml/src/ggml-cuda/vendors/hip.h
@@ -22,7 +22,10 @@
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
+#define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width)
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
+#define __all_sync(mask, var) __all(var)
+#define __any_sync(mask, var) __any(var)
#define cublasCreate hipblasCreate
#define cublasDestroy hipblasDestroy
#define cublasGemmEx hipblasGemmEx
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index fc6526d6d5..b9d3639448 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -249,6 +249,7 @@ typedef struct {
uint64_t nb33;
int32_t ne1;
int32_t ne2;
+ int32_t ne3;
float scale;
float max_bias;
float m0;
@@ -257,6 +258,11 @@ typedef struct {
float logit_softcap;
} ggml_metal_kargs_flash_attn_ext;
+typedef struct {
+ int32_t nrows;
+ int32_t ne20;
+} ggml_metal_kargs_flash_attn_ext_reduce;
+
typedef struct {
int32_t ne00;
int32_t ne02;
@@ -320,40 +326,31 @@ typedef struct {
} ggml_metal_kargs_mul_mv_ext;
typedef struct {
+ int32_t ne02;
int32_t ne10;
int32_t ne11; // n_expert_used (bcast)
uint64_t nb11;
uint64_t nb12;
- int32_t neh11; // n_tokens
- uint64_t nbh11;
+ int32_t ne21; // n_tokens
int32_t ne20; // n_expert_used
uint64_t nb21;
} ggml_metal_kargs_mul_mm_id_map0;
-typedef struct {
- int32_t ne20; // n_expert_used
- int32_t neh0;
- int32_t neh1;
- uint64_t nbh1;
- uint64_t nbh2;
- int32_t ne0;
- uint64_t nb1;
- uint64_t nb2;
-} ggml_metal_kargs_mul_mm_id_map1;
-
typedef struct {
int32_t ne00;
int32_t ne02;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
- int32_t neh12;
- uint64_t nbh10;
- uint64_t nbh11;
- uint64_t nbh12;
- uint64_t nbh13;
- int32_t neh0;
- int32_t neh1;
+ int32_t ne11;
+ uint64_t nb10;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb13;
+ int32_t ne20;
+ int32_t ne21;
+ int32_t ne0;
+ int32_t ne1;
int16_t r2;
int16_t r3;
} ggml_metal_kargs_mul_mm_id;
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 7c70d352df..1f93633d91 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -93,35 +93,37 @@ static id ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
if (ctx->mtl_device == nil) {
ctx->mtl_device = MTLCreateSystemDefaultDevice();
- ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
- ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
+ if (ctx->mtl_device) {
+ ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
+ ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
- ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
+ ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
- ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
+ ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
#endif
- ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
- ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
+ ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
+ ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
#if defined(GGML_METAL_USE_BF16)
- ctx->use_bfloat = ctx->has_bfloat;
+ ctx->use_bfloat = ctx->has_bfloat;
#else
- ctx->use_bfloat = false;
+ ctx->use_bfloat = false;
#endif
- ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
+ ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
- {
- const char * val = getenv("GGML_METAL_FUSION_DEBUG");
- ctx->debug_fusion = val ? atoi(val) : 0;
+ {
+ const char * val = getenv("GGML_METAL_FUSION_DEBUG");
+ ctx->debug_fusion = val ? atoi(val) : 0;
+ }
+
+ memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
+
+ ctx->max_size = ctx->mtl_device.maxBufferLength;
+
+ strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
}
-
- memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
-
- ctx->max_size = ctx->mtl_device.maxBufferLength;
-
- strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
}
ctx->mtl_device_ref_count++;
@@ -289,6 +291,10 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
@@ -396,8 +402,12 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
@@ -443,6 +453,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
@@ -452,6 +463,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
@@ -461,6 +473,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
@@ -470,6 +483,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
@@ -479,6 +493,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
@@ -488,6 +503,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
@@ -497,6 +513,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
@@ -506,6 +523,13 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64,
@@ -555,6 +579,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE,
GGML_METAL_KERNEL_TYPE_SET_I32,
GGML_METAL_KERNEL_TYPE_SET_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
@@ -1304,6 +1329,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2, mul_mv_ext_f32_f32_r1_2, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3, mul_mv_ext_f32_f32_r1_3, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4, mul_mv_ext_f32_f32_r1_4, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5, mul_mv_ext_f32_f32_r1_5, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
@@ -1412,8 +1441,12 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1, mul_mm_id_map0_f16_ne20_1, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2, mul_mm_id_map0_f16_ne20_2, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4, mul_mm_id_map0_f16_ne20_4, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6, mul_mm_id_map0_f16_ne20_6, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8, mul_mm_id_map0_f16_ne20_8, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16, mul_mm_id_map0_f16_ne20_16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat);
@@ -1459,6 +1492,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40, flash_attn_ext_f16_h40, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm);
@@ -1468,6 +1502,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, flash_attn_ext_f16_hk576_hv512, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40, flash_attn_ext_bf16_h40, has_simdgroup_mm && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
@@ -1477,6 +1512,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40, flash_attn_ext_q4_0_h40, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
@@ -1486,6 +1522,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40, flash_attn_ext_q4_1_h40, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
@@ -1495,6 +1532,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40, flash_attn_ext_q5_0_h40, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
@@ -1504,6 +1542,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40, flash_attn_ext_q5_1_h40, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
@@ -1513,6 +1552,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40, flash_attn_ext_q8_0_h40, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
@@ -1522,6 +1562,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40, flash_attn_ext_vec_f16_h40, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40, flash_attn_ext_vec_bf16_h40, has_simdgroup_reduction && use_bfloat);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40, flash_attn_ext_vec_q4_0_h40, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40, flash_attn_ext_vec_q4_1_h40, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40, flash_attn_ext_vec_q5_0_h40, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40, flash_attn_ext_vec_q5_1_h40, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40, flash_attn_ext_vec_q8_0_h40, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction);
@@ -1571,6 +1618,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE, flash_attn_ext_reduce, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
@@ -1846,7 +1894,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_ROPE:
return true;
case GGML_OP_IM2COL:
- return op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
+ return ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
case GGML_OP_POOL_1D:
return false;
case GGML_OP_UPSCALE:
@@ -3347,15 +3395,16 @@ static int ggml_metal_encode_node(
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
// to the matrix-vector kernel
- const int ne11_mm_min = 4;
+ const int ne11_mm_min = 8;
// first try to use small-batch mat-mv kernels
// these should be efficient for BS [2, ~8]
- if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) &&
+ if (src1t == GGML_TYPE_F32 && (ne00%128 == 0) &&
(
(
(
- src0t == GGML_TYPE_F16 || // TODO: helper function
+ src0t == GGML_TYPE_F32 || // TODO: helper function
+ src0t == GGML_TYPE_F16 ||
src0t == GGML_TYPE_Q4_0 ||
src0t == GGML_TYPE_Q4_1 ||
src0t == GGML_TYPE_Q5_0 ||
@@ -3383,7 +3432,17 @@ static int ggml_metal_encode_node(
// values and there can be some tail effects when nsg is high. need to confirm this
//
const int nsg = 2; // num simdgroups per threadgroup
- const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup
+
+ // num threads along row per simdgroup
+ int nxpsg = 0;
+ if (ne00 % 256 == 0 && ne11 < 3) {
+ nxpsg = 16;
+ } else if (ne00 % 128 == 0) {
+ nxpsg = 8;
+ } else {
+ nxpsg = 4;
+ }
+
const int nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
const int r0ptg = nypsg*nsg; // num src0 rows per threadgroup
int r1ptg = 4; // num src1 rows per threadgroup
@@ -3406,6 +3465,14 @@ static int ggml_metal_encode_node(
id pipeline = nil;
switch (src0->type) {
+ case GGML_TYPE_F32:
+ switch (r1ptg) {
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2].pipeline; break;
+ case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3].pipeline; break;
+ case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4].pipeline; break;
+ case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5].pipeline; break;
+ default: GGML_ABORT("not implemented");
+ } break;
case GGML_TYPE_F16:
switch (r1ptg) {
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break;
@@ -3560,7 +3627,7 @@ static int ggml_metal_encode_node(
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
- case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
+ case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
@@ -3878,38 +3945,6 @@ static int ggml_metal_encode_node(
default: break;
}
- const int64_t neh10 = ne10; // n_embd
- const int64_t neh11 = ne21; // n_tokens
- const int64_t neh12 = ne02; // n_expert
-
- const uint64_t nbh10 = ggml_type_size(GGML_TYPE_F16);
- const uint64_t nbh11 = nbh10*neh10;
- const uint64_t nbh12 = nbh11*neh11;
- const uint64_t nbh13 = nbh12*neh12;
-
- const size_t s_src1 = ggml_type_size(GGML_TYPE_F16)*neh10*neh11*neh12;
- id h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
- if (!h_src1) {
- GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
- return 0;
- }
-
- const int64_t neh0 = ne0;
- const int64_t neh1 = ne21;
- const int64_t neh2 = ne02;
-
- const uint64_t nbh0 = ggml_type_size(GGML_TYPE_F32);
- const uint64_t nbh1 = nbh0*neh0;
- const uint64_t nbh2 = nbh1*neh1;
- //const uint64_t nbh3 = nbh2*neh2;
-
- const size_t s_dst = ggml_type_size(GGML_TYPE_F32)*neh0*neh1*neh2;
- id h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
- if (!h_dst) {
- GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
- return 0;
- }
-
// tokens per expert
const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
id h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
@@ -3919,8 +3954,8 @@ static int ggml_metal_encode_node(
}
// id map
- // [n_expert_used, n_tokens]
- const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne20*ne21;
+ // [n_tokens, n_expert]
+ const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne21*ne02;
id h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
if (!h_ids) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
@@ -3928,32 +3963,45 @@ static int ggml_metal_encode_node(
}
{
- const int nth = MIN(1024, ne10/4);
-
ggml_metal_kargs_mul_mm_id_map0 args = {
+ ne02,
ne10,
- ne11, // n_expert_used (bcast)
+ ne11, // n_expert_used (bcast)
nb11,
nb12,
- neh11, // n_tokens
- nbh11,
- ne20, // n_expert_used
+ ne21, // n_tokens
+ ne20, // n_expert_used
nb21,
};
id pipeline = nil;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline;
+ // map0 has one kernel per supported ne20 (n_expert_used): pick the matching pipeline, abort if there is none
+
+ switch (ne20) {
+ case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1 ].pipeline; break;
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2 ].pipeline; break;
+ case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4 ].pipeline; break;
+ case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6 ].pipeline; break;
+ case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8 ].pipeline; break;
+ case 16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16].pipeline; break;
+ default: GGML_ABORT("missing specialization for ne20 = %d", (int) ne20);
+ }
+
+ GGML_ASSERT(ne02 <= (int) pipeline.maxTotalThreadsPerThreadgroup);
+
+ const size_t smem = ne02*ne20*sizeof(uint16_t);
+
+ GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
- [encoder setBuffer: h_src1 offset:0 atIndex:3];
- [encoder setBuffer: h_tpe offset:0 atIndex:4];
- [encoder setBuffer: h_ids offset:0 atIndex:5];
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:1];
+ [encoder setBuffer: h_tpe offset:0 atIndex:2];
+ [encoder setBuffer: h_ids offset:0 atIndex:3];
+ [encoder setThreadgroupMemoryLength:smem atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)];
}
{
@@ -3992,13 +4040,15 @@ static int ggml_metal_encode_node(
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
- /*.neh12 =*/ neh12,
- /*.nbh10 =*/ nbh10,
- /*.nbh11 =*/ nbh11,
- /*.nbh12 =*/ nbh12,
- /*.nbh13 =*/ nbh13,
- /*.neh0 =*/ neh0,
- /*.neh1 =*/ neh1,
+ /*.ne11 =*/ ne11, // n_expert_used (bcast)
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.ne20 =*/ ne20, // n_expert_used
+ /*.ne21 =*/ ne21, // n_tokens
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
/*.r2 =*/ r2,
/*.r3 =*/ r3,
};
@@ -4006,42 +4056,14 @@ static int ggml_metal_encode_node(
[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
- [encoder setBuffer: h_src1 offset:0 atIndex:2];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer: h_tpe offset:0 atIndex:3];
- [encoder setBuffer: h_dst offset:0 atIndex:4];
+ [encoder setBuffer: h_ids offset:0 atIndex:4];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:5];
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
}
-
- {
- GGML_ASSERT(ne0 % 4 == 0);
-
- const int nth = MIN(1024, ne0/4);
-
- ggml_metal_kargs_mul_mm_id_map1 args = {
- ne20, // n_expert_used
- neh0,
- neh1,
- nbh1,
- nbh2,
- ne0,
- nb1,
- nb2,
- };
-
- id pipeline = nil;
-
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBytes:&args length:sizeof(args) atIndex:0];
- [encoder setBuffer: h_dst offset:0 atIndex:1];
- [encoder setBuffer: h_ids offset:0 atIndex:2];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- }
} else {
id pipeline = nil;
@@ -4701,7 +4723,6 @@ static int ggml_metal_encode_node(
} break;
case GGML_OP_IM2COL:
{
- GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
@@ -5130,6 +5151,7 @@ static int ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline;
} else {
switch (ne00) {
+ case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40 ].pipeline; break;
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
@@ -5154,6 +5176,7 @@ static int ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline;
} else {
switch (ne00) {
+ case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40 ].pipeline; break;
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
@@ -5178,6 +5201,7 @@ static int ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline;
} else {
switch (ne00) {
+ case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40 ].pipeline; break;
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
@@ -5202,6 +5226,7 @@ static int ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline;
} else {
switch (ne00) {
+ case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40 ].pipeline; break;
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
@@ -5226,6 +5251,7 @@ static int ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline;
} else {
switch (ne00) {
+ case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40 ].pipeline; break;
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
@@ -5250,6 +5276,7 @@ static int ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline;
} else {
switch (ne00) {
+ case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40 ].pipeline; break;
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
@@ -5274,6 +5301,7 @@ static int ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline;
} else {
switch (ne00) {
+ case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40 ].pipeline; break;
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
@@ -5301,6 +5329,24 @@ static int ggml_metal_encode_node(
use_vec_kernel = true;
switch (ne00) {
+ case 40:
+ {
+ switch (src1->type) {
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40].pipeline; break;
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40].pipeline; break;
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40].pipeline; break;
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40].pipeline; break;
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40].pipeline; break;
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40].pipeline; break;
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40].pipeline; break;
+ default:
+ {
+ GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+ GGML_LOG_ERROR("add template specialization for this type\n");
+ GGML_ABORT("add template specialization for this type");
+ }
+ }
+ } break;
case 64:
{
switch (src1->type) {
@@ -5465,6 +5511,7 @@ static int ggml_metal_encode_node(
/*.nb33 =*/ nb33,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
/*.scale =*/ scale,
/*.max_bias =*/ max_bias,
/*.m0 =*/ m0,
@@ -5488,7 +5535,6 @@ static int ggml_metal_encode_node(
} else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
}
- [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
if (!use_vec_kernel) {
// half8x8 kernel
@@ -5514,7 +5560,7 @@ static int ggml_metal_encode_node(
while (true) {
const size_t smem = FATTN_SMEM(nsgmax);
- if (smem > device.maxThreadgroupMemoryLength) {
+ if (smem > device.maxThreadgroupMemoryLength/2) {
break;
}
nsgmax *= 2;
@@ -5526,15 +5572,18 @@ static int ggml_metal_encode_node(
const size_t smem = FATTN_SMEM(nsg);
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
+
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
[encoder setThreadgroupMemoryLength:smem atIndex:0];
-#undef FATTN_SMEM
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+#undef FATTN_SMEM
} else {
// half4x4 kernel
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+ const int64_t nkpsg = 1*ncpsg; // TODO: make adjustable
GGML_ASSERT(nqptg <= 32);
GGML_ASSERT(nqptg % 1 == 0);
@@ -5544,15 +5593,17 @@ static int ggml_metal_encode_node(
// for each query, we load it as f16 in shared memory (ne00)
// and store the soft_max values and the mask
//
- // ne00*(nsg)
+ // ne20*(nsg)
// each simdgroup has a full f32 head vector in shared mem to accumulate results
//
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))
+//#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)))*(sizeof(float)/2), 16))
int64_t nsgmax = 2;
while (true) {
const size_t smem = FATTN_SMEM(nsgmax);
- if (smem > device.maxThreadgroupMemoryLength) {
+ // avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes
+ if (smem > device.maxThreadgroupMemoryLength/2) {
break;
}
nsgmax *= 2;
@@ -5560,7 +5611,7 @@ static int ggml_metal_encode_node(
nsgmax /= 2;
// simdgroups per threadgroup (a.k.a. warps)
- const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
+ const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
int64_t nsg = 1;
while (nsg <= nsgt) {
@@ -5568,13 +5619,74 @@ static int ggml_metal_encode_node(
}
nsg /= 2;
- const size_t smem = FATTN_SMEM(nsg);
+ // workgroups
+ // each workgroup handles nsg*nkpsg cache values
+ uint16_t nwg = 1;
+ if (4*nsg*nkpsg >= ne11) {
+ const size_t smem = FATTN_SMEM(nsg);
- //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
- GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
- [encoder setThreadgroupMemoryLength:smem atIndex:0];
+ //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
+ GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
+
+ // using 1 workgroup -> write the result directly into dst
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
+ [encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];
+
+ [encoder setThreadgroupMemoryLength:smem atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+ } else {
+ nwg = 32;
+ nsg = MIN(4, nsg);
+
+ const size_t smem = FATTN_SMEM(nsg);
+
+ //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
+ GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
+
+ // sanity checks
+ GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
+ GGML_ASSERT(ne1*ne2*ne3 < (1u << 31)); // strict bound: the row count must fit the int32_t below (2^31 itself does not)
+
+ const int32_t nrows = ne1*ne2*ne3;
+
+ // temp buffer for writing the results from each workgroup
+ // - ne20: the size of the head vector
+ // - + 2: the S and M values for each intermediate result
+ const size_t s_tmp = ggml_type_size(GGML_TYPE_F32)*((size_t) nrows*nwg*(ne20 + 2)); // widen first: nrows*nwg would otherwise be computed in 32-bit int
+ id h_tmp = ggml_metal_mem_pool_alloc(mem_pool, s_tmp);
+ if (!h_tmp) {
+ GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tmp);
+ return 0;
+ }
+
+ //printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20);
+ //printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f);
+
+ [encoder setBuffer:h_tmp offset:0 atIndex:6];
+ [encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];
+
+ [encoder setThreadgroupMemoryLength:smem atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+
+ // reduce the results from the workgroups
+ {
+ ggml_metal_kargs_flash_attn_ext_reduce args0 = {
+ nrows,
+ ne20,
+ };
+
+ id pipeline0 = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE].pipeline;
+
+ [encoder setComputePipelineState:pipeline0];
+ [encoder setBytes:&args0 length:sizeof(args0) atIndex:0];
+ [encoder setBuffer:h_tmp offset:0 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+
+ //printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20);
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(32*32, 1, 1)];
+ }
+ }
#undef FATTN_SMEM
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
}
} break;
case GGML_OP_DUP:
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index b35a3bbdc3..4fa16c4a55 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -68,6 +68,11 @@ void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg)
reg = (type4x4)(*src);
}
+template <typename type4> // type4: destination 4-wide vector type chosen by the instantiating kernel
+void dequantize_f32_t4(device const float4 * src, short il, thread type4 & reg) {
+ reg = (type4)(*src); // f32 needs no dequantization: convert one float4 to type4; `il` is unused here
+}
+
template
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
reg = (type4x4)(*src);
@@ -974,9 +979,16 @@ kernel void kernel_mul(
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
- const int i10 = i0%args.ne10;
- *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
+ if (args.ne10 == 1) {
+ const float x = *((device float *)(src1_ptr));
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
+ }
+ } else {
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ const int i10 = i0%args.ne10;
+ *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
+ }
}
}
@@ -1000,9 +1012,16 @@ kernel void kernel_div(
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
- const int i10 = i0%args.ne10;
- *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
+ if (args.ne10 == 1) {
+ const float x = 1.0f / *((device float *)(src1_ptr));
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
+ }
+ } else {
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ const int i10 = i0%args.ne10;
+ *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
+ }
}
}
@@ -1964,14 +1983,15 @@ kernel void kernel_ssm_scan_f32(
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
const int64_t i = i0 + i1*nc;
+ const int64_t g = ir / (nh / ng); // repeat_interleave
float s0 = s0_buff[i];
float s = s_buff[i];
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
- device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
- device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+ device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
+ device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
for (int64_t i2 = 0; i2 < n_t; ++i2) {
@@ -2079,14 +2099,15 @@ kernel void kernel_ssm_scan_f32_group(
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
const int64_t i = i0 + i1*nc;
+ const int64_t g = ir / (nh / ng); // repeat_interleave
float s0 = s0_buff[i];
float s = s_buff[i];
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
- device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
- device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+ device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
+ device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
for (int64_t i2 = 0; i2 < n_t; ++i2) {
@@ -3001,7 +3022,6 @@ void kernel_mul_mv_ext_q4_f32_impl(
#pragma unroll(r1ptg)
for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]);
-
}
}
@@ -3186,6 +3206,11 @@ kernel void kernel_mul_mv_ext_q4x4_f32_disp(
typedef decltype(kernel_mul_mv_ext_q4_f32_disp <2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t;
typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>) mul_mv_ext_q4x4_f32_t;
+template [[host_name("kernel_mul_mv_ext_f32_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, float4, 4, dequantize_f32_t4>; // f32 src0 variants: first template argument (2..5) matches the _r1_N suffix of the host name
+template [[host_name("kernel_mul_mv_ext_f32_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, float4, 4, dequantize_f32_t4>;
+template [[host_name("kernel_mul_mv_ext_f32_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, float4, 4, dequantize_f32_t4>;
+template [[host_name("kernel_mul_mv_ext_f32_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, float4, 4, dequantize_f32_t4>;
+
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4, 4, dequantize_f16_t4>;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4, 4, dequantize_f16_t4>;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>;
@@ -4663,6 +4688,7 @@ kernel void kernel_flash_attn_ext(
typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t;
+template [[host_name("kernel_flash_attn_ext_f16_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -4674,6 +4700,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_flash_attn_ext_bf16_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -4685,6 +4712,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
#endif
+template [[host_name("kernel_flash_attn_ext_q4_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -4695,6 +4723,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q4_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -4705,6 +4734,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q4_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -4715,6 +4745,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q5_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -4725,6 +4756,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q5_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -4765,14 +4797,16 @@ kernel void kernel_flash_attn_ext_vec(
device const char * mask,
device const char * sinks,
device char * dst,
+ constant uint16_t & nwg,
threadgroup half * shmem_f16 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 ntg[[threads_per_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const short nsg = ntg.y; // number of simdgroups
+ const short iwg = tgpig[2]%nwg;
- const int iq3 = tgpig[2];
+ const int iq3 = tgpig[2]/nwg;
const int iq2 = tgpig[1];
const int iq1 = tgpig[0];
@@ -4851,7 +4885,7 @@ kernel void kernel_flash_attn_ext_vec(
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
- for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
+ for (int ic0 = (int) iwg*C*nsg; ic0 < args.ne11; ic0 += (int) nwg*C*nsg) {
const int ic = ic0 + C*sgitg;
if (ic >= args.ne11) {
break;
@@ -4981,7 +5015,7 @@ kernel void kernel_flash_attn_ext_vec(
}
}
- if (sinks != q && sgitg == 0) {
+ if (sinks != q && sgitg == 0 && iwg == 0) {
const float m = M;
const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
@@ -5090,14 +5124,25 @@ kernel void kernel_flash_attn_ext_vec(
threadgroup_barrier(mem_flags::mem_threadgroup);
}
- device float4 * dst4 = (device float4 *) dst;
-
// final rescale with 1/S and store to global memory
if (sgitg == 0) {
- const float S = ss[0];
+ const int64_t nrows = args.ne3*args.ne2*args.ne1;
+ const int64_t rid = iq3*args.ne2*args.ne1 + iq2 + iq1*args.ne1;
+ device float4 * dst4 = (device float4 *) dst;
+ device float * dst1 = (device float *) dst + nrows*DV*nwg; // the S and M are stored after the results
+
+ const float S = nwg == 1 ? 1.0f/ss[0] : 1.0f;
+
+ // interleave the workgroup data
for (short i = tiisg; i < DV4; i += NW) {
- dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*DV4 + i] = (float4) sr4[i]/S;
+ dst4[rid*DV4*nwg + nwg*i + iwg] = (float4) sr4[i]*S;
+ }
+
+ // store S and M
+ if (nwg > 1 && tiisg == 0) {
+ dst1[rid*(2*nwg) + 2*iwg + 0] = ss[0];
+ dst1[rid*(2*nwg) + 2*iwg + 1] = ss[1];
}
}
}
@@ -5115,6 +5160,16 @@ kernel void kernel_flash_attn_ext_vec(
typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t;
+template [[host_name("kernel_flash_attn_ext_vec_f16_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_flash_attn_ext_vec_bf16_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#endif
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+
template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
#if defined(GGML_METAL_USE_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
@@ -5187,6 +5242,41 @@ template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flas
#undef FA_TYPES
+kernel void kernel_flash_attn_ext_reduce( // merges the per-workgroup partial results in htmp into the final rows of dst
+ constant ggml_metal_kargs_flash_attn_ext_reduce & args, // carries nrows and ne20 (head vector size)
+ device const char * htmp, // temp buffer: interleaved partial results, followed by the (S, M) pairs
+ device char * dst, // final output, written as float4
+ uint tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ const uint64_t rid = tgpig; // one threadgroup per output row
+
+ const short nwg = 32; // hard-coded; must equal the nwg the host passes to kernel_flash_attn_ext_vec
+ const short iwg = tiisg; // each SIMD lane owns one workgroup's partial result (assumes 32-wide simdgroups)
+ const short DV = args.ne20;
+ const short DV4 = DV/4; // number of float4 elements per row
+
+ device const float4 * htmp4 = (device const float4 *) htmp + rid*DV4*nwg; // this row's partials, indexed [i*nwg + iwg]
+ device const float * ss = (device const float *) htmp + (uint64_t)args.nrows*DV*nwg; // (S, M) pairs start after all partial results
+ device float4 * dst4 = (device float4 *) dst + rid*DV4;
+
+ float S = ss[rid*(2*nwg) + 2*iwg + 0]; // slot 0 of the pair written by the vec kernel
+ float M = ss[rid*(2*nwg) + 2*iwg + 1]; // slot 1 of the pair written by the vec kernel
+
+ const float m = simd_max(M); // largest M across the nwg workgroups
+ const float ms = exp(M - m); // factor that rescales this workgroup's partial to the common maximum
+
+ S = 1.0f/simd_sum(S*ms); // reciprocal of the combined, rescaled sum
+
+ for (int i = sgitg; i < DV4; i += nwg) { // simdgroups stride over the float4 elements of the row
+ const float4 v = simd_sum(htmp4[i*nwg + iwg]*ms); // collective: every lane must execute this
+
+ if (iwg == 0) { // a single lane stores the reduced value
+ dst4[i] = v*S;
+ }
+ }
+}
+
template
kernel void kernel_set(
constant ggml_metal_kargs_set & args,
@@ -7474,97 +7564,81 @@ kernel void kernel_mul_mm(
}
}
-template<typename T4>
+template<short ne20> // n_expert_used
kernel void kernel_mul_mm_id_map0(
constant ggml_metal_kargs_mul_mm_id_map0 & args,
- device const char * src1,
device const char * src2,
- device char * hsrc1,
device char * htpe,
device char * hids,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort3 tpitg[[thread_position_in_threadgroup]],
- ushort3 ntg[[threads_per_threadgroup]]) {
- const int ide = tgpig[0]; // expert id
+ threadgroup char * shmem [[threadgroup(0)]],
+ ushort tpitg[[thread_position_in_threadgroup]],
+ ushort ntg[[threads_per_threadgroup]]) {
+ const short ide = tpitg; // expert id
- int n_all = 0;
+ uint32_t n_all = 0; // number of entries recorded so far for this expert
- device int32_t * ids_i32 = (device int32_t *) (hids);
+ device int32_t * ids_i32 = (device int32_t *) hids + ide*args.ne21; // this expert's slice of hids: ne21 slots per expert
- for (int i21 = 0; i21 < args.neh11; i21++) { // n_tokens
- device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21);
+ for (int i21 = 0; i21 < args.ne21; i21 += ntg) { // n_tokens
+ if (i21 + tpitg < args.ne21) { // each thread stages the ids of one token of the current batch
+ device const int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21);
- for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used
- if (src2_i32[i20] != ide) {
- continue;
+ threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*ne20; // ne20 staged ids per token
+
+ #pragma unroll(ne20)
+ for (short i20 = 0; i20 < ne20; i20++) {
+ sids[i20] = src2_i32[i20]; // int32 id is narrowed to uint16
}
-
- device const float4 * src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11);
- device T4 * hsrc1_f32x4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11);
-
- for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) {
- hsrc1_f32x4[i00] = (T4) (src1_f32x4[i00]);
- }
-
- if (tpitg.x == 0) {
- ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all;
- }
-
- ++n_all;
}
+
+ threadgroup_barrier(mem_flags::mem_threadgroup); // staged ids must be visible to every thread before scanning
+
+ for (short t = 0; t < ntg; t++) { // scan the staged tokens in order
+ if (i21 + t >= args.ne21) {
+ break;
+ }
+
+ threadgroup const uint16_t * sids = (threadgroup const uint16_t *) shmem + t*ne20;
+
+ short sel = 0; // 1-based slot of this expert in token t's list, 0 if absent (presumably an expert appears at most once per token)
+ #pragma unroll(ne20)
+ for (short i20 = 0; i20 < ne20; i20++) {
+ sel += (sids[i20] == ide)*(i20 + 1);
+ }
+
+ ids_i32[n_all] = (i21 + t)*ne20 + sel - 1; // written unconditionally; kept only when sel > 0 because n_all advances only then
+
+ n_all += sel > 0;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup); // do not restage until every thread has consumed this batch
}
- if (tpitg.x == 0) {
- device int32_t * tpe_i32 = (device int32_t *) (htpe);
- tpe_i32[ide] = n_all;
- }
+ device uint32_t * tpe_u32 = (device uint32_t *) (htpe);
+ tpe_u32[ide] = n_all; // per-expert count of recorded entries
}
-typedef decltype(kernel_mul_mm_id_map0) kernel_mul_mm_id_map0_t;
+typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;
-template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0;
-
-template
-kernel void kernel_mul_mm_id_map1(
- constant ggml_metal_kargs_mul_mm_id_map1 & args,
- device const char * hdst,
- device const char * hids,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort3 tpitg[[thread_position_in_threadgroup]],
- ushort3 ntg[[threads_per_threadgroup]]) {
- const int i20 = tgpig[0]; // used expert
- const int i21 = tgpig[1]; // token
-
- device const int32_t * ids_i32 = (device const int32_t *) (hids);
- device float4 * dst_f32x4 = (device float4 *) (dst + i20*args.nb1 + i21*args.nb2);
-
- const int id = ids_i32[i21*args.ne20 + i20];
-
- const int ide = id / args.neh1;
- const int idt = id % args.neh1;
-
- device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2);
-
- for (int64_t i0 = tpitg.x; i0 < args.neh0/4; i0 += ntg.x) {
- dst_f32x4[i0] = hdst_f32x4[i0];
- }
-}
-
-typedef decltype(kernel_mul_mm_id_map1) kernel_mul_mm_id_map1_t;
-
-template [[host_name("kernel_mul_mm_id_map1_f32")]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
template
kernel void kernel_mul_mm_id(
constant ggml_metal_kargs_mul_mm_id & args,
device const char * src0,
device const char * src1,
- device const char * tpe,
+ device const char * htpe,
+ device const char * hids,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
threadgroup T * sa = (threadgroup T *)(shmem);
@@ -7572,19 +7646,20 @@ kernel void kernel_mul_mm_id(
const int r0 = tgpig.y;
const int r1 = tgpig.x;
- const int im = tgpig.z;
+ const int im = tgpig.z; // expert
- device const int32_t * tpe_i32 = (device const int32_t *) (tpe);
+ device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
+ device const int32_t * ids_i32 = (device const int32_t *) (hids);
- const int neh1 = tpe_i32[im];
+ const int32_t neh1 = tpe_u32[im];
if (r1*BLOCK_SIZE_N >= neh1) {
return;
}
// if this block is of 64x32 shape or smaller
- const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
- const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
+ const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
+ const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
// a thread shouldn't load data outside of the matrix
const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
@@ -7600,20 +7675,23 @@ kernel void kernel_mul_mm_id(
short il = (tiitg % THREAD_PER_ROW);
- const int i12 = im%args.neh12;
- const int i13 = im/args.neh12;
+ const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + thread_col];
- const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const short i11 = (id % args.ne20) % args.ne11;
+ const short i12 = (id / args.ne20);
+ const short i13 = 0;
+
+ const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
const short offset1 = il/nl;
device const block_q * x = (device const block_q *)(src0
+ args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
- device const half * y = (device const half *)(src1
- + args.nbh13*i13
- + args.nbh12*i12
- + args.nbh11*(r1*BLOCK_SIZE_N + thread_col)
- + args.nbh10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+ device const float * y = (device const float *)(src1
+ + args.nb13*i13
+ + args.nb12*i12
+ + args.nb11*i11
+ + args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
// load data and store to threadgroup memory
@@ -7629,7 +7707,7 @@ kernel void kernel_mul_mm_id(
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
}
- *(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device half2x4 *) y);
+ *(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (half2x4)(*((device float2x4 *) y));
il = (il + 2 < nl) ? il + 2 : il % 2;
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
@@ -7665,43 +7743,38 @@ kernel void kernel_mul_mm_id(
}
}
- if ((r0 + 1) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1) * BLOCK_SIZE_N <= neh1) {
- device float * C = (device float *) dst +
- (BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \
- (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.neh0 + im*args.neh1*args.neh0;
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- for (short i = 0; i < 8; i++) {
- simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.neh0 * (i/4), args.neh0);
- }
- } else {
- // block is smaller than 64x32, we should avoid writing data outside of the matrix
- threadgroup_barrier(mem_flags::mem_threadgroup);
- threadgroup float * temp_str = ((threadgroup float *) shmem) \
- + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
- for (short i = 0; i < 8; i++) {
- simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
+ threadgroup float * temp_str = ((threadgroup float *) shmem) \
+ + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
+
+ #pragma unroll(8)
+ for (short i = 0; i < 8; i++) {
+ simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (short j = sgitg; j < n_cols; j += 4) {
+ const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j];
+
+ const short ide = id % args.ne20;
+ const short idt = id / args.ne20;
+
+ device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1*args.ne0;
+ device float4 * D4 = (device float4 *) D;
+
+ threadgroup float * C = (threadgroup float *) shmem + (j*BLOCK_SIZE_M);
+ threadgroup float4 * C4 = (threadgroup float4 *) C;
+
+ int i = tiisg;
+ for (; i < n_rows/4; i += 32) {
+ *(D4 + i) = *(C4 + i);
}
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- if (sgitg == 0) {
- for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
- device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1*args.neh0;
- device float4 * D4 = (device float4 *) D;
-
- threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
- threadgroup float4 * C4 = (threadgroup float4 *) C;
-
- int i = 0;
- for (; i < n_rows/4; i++) {
- *(D4 + i) = *(C4 + i);
- }
-
- i *= 4;
- for (; i < n_rows; i++) {
- *(D + i) = *(C + i);
- }
- }
+ i = (4*(n_rows/4)) + tiisg;
+ for (; i < n_rows; i += 32) {
+ *(D + i) = *(C + i);
}
}
}
diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
index df27501361..c25c2daaf6 100644
--- a/ggml/src/ggml-opencl/ggml-opencl.cpp
+++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
@@ -420,9 +420,9 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_clamp;
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick,
kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
- cl_kernel kernel_norm;
+ cl_kernel kernel_norm, kernel_norm_mul_add;
cl_kernel kernel_rms_norm, kernel_rms_norm_mul;
- cl_kernel kernel_group_norm;
+ cl_kernel kernel_group_norm, kernel_group_norm_mul_add;
cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
cl_kernel kernel_soft_max, kernel_soft_max_4;
cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
@@ -1161,7 +1161,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
backend_ctx->program_norm =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
- CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program_norm, "kernel_norm", &err), err));
+ CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program_norm, "kernel_norm", &err), err));
+ CL_CHECK((backend_ctx->kernel_norm_mul_add = clCreateKernel(backend_ctx->program_norm, "kernel_norm_mul_add", &err), err));
GGML_LOG_CONT(".");
}
@@ -1487,7 +1488,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
backend_ctx->program_group_norm =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
- CL_CHECK((backend_ctx->kernel_group_norm = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm", &err), err));
+ CL_CHECK((backend_ctx->kernel_group_norm = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm", &err), err));
+ CL_CHECK((backend_ctx->kernel_group_norm_mul_add = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm_mul_add", &err), err));
GGML_LOG_CONT(".");
}
@@ -2498,12 +2500,47 @@ static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
return false;
}
+ } else if (ops.size() == 3 && ops.begin()[0] == GGML_OP_NORM && ops.begin()[1] == GGML_OP_MUL && ops.begin()[2] == GGML_OP_ADD) {
+ const ggml_tensor *norm = cgraph->nodes[node_idx];
+ const ggml_tensor *mul = cgraph->nodes[node_idx+1];
+ const ggml_tensor *add = cgraph->nodes[node_idx+2];
+ const ggml_tensor *w = mul->src[0] == norm ? mul->src[1] : mul->src[0];
+ const ggml_tensor *b = add->src[0] == mul ? add->src[1] : add->src[0];
+
+ // norm fusion only supports F32
+ if (norm->src[0]->type != GGML_TYPE_F32 || w->type != GGML_TYPE_F32 || b->type != GGML_TYPE_F32) {
+ return false;
+ }
+
+ if (norm->src[0]->ne[0] % 4 != 0) {
+ return false;
+ }
+
+ if (!ggml_is_contiguous(norm->src[0]) || !ggml_is_contiguous(w) || !ggml_is_contiguous(b)) {
+ return false;
+ }
+ } else if (ops.size() == 3 && ops.begin()[0] == GGML_OP_GROUP_NORM && ops.begin()[1] == GGML_OP_MUL && ops.begin()[2] == GGML_OP_ADD) {
+ const ggml_tensor *gn = cgraph->nodes[node_idx];
+ const ggml_tensor *mul = cgraph->nodes[node_idx+1];
+ const ggml_tensor *add = cgraph->nodes[node_idx+2];
+ const ggml_tensor *w = mul->src[0] == gn ? mul->src[1] : mul->src[0];
+ const ggml_tensor *b = add->src[0] == mul ? add->src[1] : add->src[0];
+
+ if (gn->src[0]->type != GGML_TYPE_F32 || w->type != GGML_TYPE_F32 || b->type != GGML_TYPE_F32) {
+ return false;
+ }
+
+ if (!ggml_is_contiguous(gn->src[0]) || !ggml_is_contiguous(w) || !ggml_is_contiguous(b)) {
+ return false;
+ }
}
return true;
}
static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor);
+static void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor); // forward declaration: called from graph_compute for a fusable NORM, MUL, ADD sequence
+static void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor); // forward declaration: called from graph_compute for a fusable GROUP_NORM, MUL, ADD sequence
static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
@@ -2520,6 +2557,16 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
continue;
}
+ if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_NORM, GGML_OP_MUL, GGML_OP_ADD })) {
+ ggml_opencl_op_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
+ i += 2;
+ continue;
+ }
+ if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_GROUP_NORM, GGML_OP_MUL, GGML_OP_ADD })) {
+ ggml_opencl_op_group_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
+ i += 2;
+ continue;
+ }
if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ggml_opencl_op_rms_norm_fused(backend, node, cgraph->nodes[i+1]);
i++;
@@ -2647,8 +2694,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SOFT_MAX:
case GGML_OP_NORM:
- case GGML_OP_RMS_NORM:
return true;
+ case GGML_OP_RMS_NORM:
+ return op->ne[0] % 4 == 0 && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_REPEAT:
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
case GGML_OP_PAD:
@@ -5038,6 +5086,140 @@ static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor *
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
+// Launches kernel_norm_mul_add to evaluate a fused NORM -> MUL -> ADD chain in one pass.
+//
+// norm_tensor : the NORM node; norm_tensor->src[0] is the input, op_params holds eps.
+// mul_tensor  : the MUL node consuming norm_tensor; its other operand is used as the weight.
+// add_tensor  : the ADD node consuming mul_tensor; its other operand is used as the bias.
+//               add_tensor also receives the fused result.
+//
+// One work-group is launched per input row (ne01 x ne02 x ne03 groups). The kernel reads
+// rows as float4 chunks, so ne[0] of the input must be a positive multiple of 4.
+static void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) {
+    GGML_ASSERT(norm_tensor && mul_tensor && add_tensor);
+
+    // The weight/bias are whichever operand of MUL/ADD is not the previous node of the chain.
+    const ggml_tensor * src0 = norm_tensor->src[0];
+    const ggml_tensor * src1 = mul_tensor->src[0] == norm_tensor ? mul_tensor->src[1] : mul_tensor->src[0];
+    const ggml_tensor * src2 = add_tensor->src[0] == mul_tensor ? add_tensor->src[1] : add_tensor->src[0];
+    const ggml_tensor * dst  = add_tensor;
+
+    // Fail loudly instead of dereferencing a null operand or a tensor without a device buffer.
+    GGML_ASSERT(src0 && src1 && src2);
+    GGML_ASSERT(src0->extra && src1->extra && src2->extra && dst->extra);
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offset2 = extra2->offset + src2->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    float eps;
+    memcpy(&eps, norm_tensor->op_params, sizeof(float));
+
+    const int ne00 = src0->ne[0], ne01 = src0->ne[1], ne02 = src0->ne[2], ne03 = src0->ne[3];
+    const cl_ulong nb01 = src0->nb[1], nb02 = src0->nb[2], nb03 = src0->nb[3];
+    const int ne10 = src1->ne[0], ne11 = src1->ne[1], ne12 = src1->ne[2], ne13 = src1->ne[3];
+    const cl_ulong nb11 = src1->nb[1], nb12 = src1->nb[2], nb13 = src1->nb[3];
+    const int ne20 = src2->ne[0], ne21 = src2->ne[1], ne22 = src2->ne[2], ne23 = src2->ne[3];
+    const cl_ulong nb21 = src2->nb[1], nb22 = src2->nb[2], nb23 = src2->nb[3];
+    const cl_ulong nbd1 = dst->nb[1], nbd2 = dst->nb[2], nbd3 = dst->nb[3];
+
+    // The kernel only visits ne00/4 float4 chunks per row: a row shorter than 4 would yield a
+    // zero local size below, and a non-multiple of 4 would silently drop the row tail.
+    GGML_ASSERT(ne00 > 0 && ne00 % 4 == 0);
+
+    // Subgroup size the kernel is compiled for on each supported GPU family.
+    size_t sgs = 0;
+    if (backend_ctx->gpu_family == ADRENO) {
+        sgs = 64;
+    } else if (backend_ctx->gpu_family == INTEL) {
+        sgs = 32;
+    } else {
+        GGML_ASSERT(false && "Unsupported GPU");
+    }
+
+    cl_kernel kernel = backend_ctx->kernel_norm_mul_add;
+
+    // Start from one subgroup and double until the row's float4 chunks are covered or the
+    // kernel's work-group limit is reached, then clamp to both bounds.
+    int nth = (int)sgs;
+    const int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
+    while (nth < ne00/4 && nth < max_workgroup_size) {
+        nth *= 2;
+    }
+    nth = MIN(nth, max_workgroup_size);
+    nth = MIN(nth, ne00/4);
+
+    size_t gws[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
+    size_t lws[] = {(size_t)nth, 1, 1};
+    // One float2 of local scratch per subgroup (partial sum and partial sum of squares).
+    size_t num_subgroups = (nth + sgs - 1) / sgs;
+
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extra2->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset2));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne00));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne01));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne02));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne03));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne10));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne11));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &ne12));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &ne13));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
+    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
+    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &ne20));
+    CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &ne21));
+    CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),      &ne22));
+    CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int),      &ne23));
+    CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb21));
+    CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb22));
+    CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb23));
+    CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nbd1));
+    CL_CHECK(clSetKernelArg(kernel, 30, sizeof(cl_ulong), &nbd2));
+    CL_CHECK(clSetKernelArg(kernel, 31, sizeof(cl_ulong), &nbd3));
+    CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float),    &eps));
+    CL_CHECK(clSetKernelArg(kernel, 33, sizeof(cl_float2) * num_subgroups, NULL));
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 3, gws, lws, dst);
+}
+
+// Launches kernel_group_norm_mul_add to evaluate a fused GROUP_NORM -> MUL -> ADD chain.
+//
+// gn_tensor  : the GROUP_NORM node; gn_tensor->src[0] is the input, op_params holds the
+//              group count (int) followed by eps (float).
+// mul_tensor : the MUL node consuming gn_tensor; its other operand is used as the weight.
+// add_tensor : the ADD node consuming mul_tensor; its other operand is used as the bias.
+//              add_tensor also receives the fused result.
+//
+// One work-group is launched per group. The kernel combines partial sums with
+// sub_group_reduce_add only, so a work-group must never span more than one subgroup.
+// NOTE(review): group_size is taken as nelements / groups, i.e. the input is treated as
+// `groups` equal contiguous spans - confirm the fusion check guarantees that.
+static void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) {
+    GGML_ASSERT(gn_tensor && mul_tensor && add_tensor);
+
+    // The weight/bias are whichever operand of MUL/ADD is not the previous node of the chain.
+    const ggml_tensor * src0 = gn_tensor->src[0];
+    const ggml_tensor * src1 = mul_tensor->src[0] == gn_tensor ? mul_tensor->src[1] : mul_tensor->src[0];
+    const ggml_tensor * src2 = add_tensor->src[0] == mul_tensor ? add_tensor->src[1] : add_tensor->src[0];
+    const ggml_tensor * dst  = add_tensor;
+
+    // Fail loudly instead of dereferencing a null operand or a tensor without a device buffer.
+    GGML_ASSERT(src0 && src1 && src2);
+    GGML_ASSERT(src0->extra && src1->extra && src2->extra && dst->extra);
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offset2 = extra2->offset + src2->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    int groups;
+    float eps;
+    memcpy(&groups, gn_tensor->op_params, sizeof(int));
+    memcpy(&eps, (char *)gn_tensor->op_params + sizeof(int), sizeof(float));
+    GGML_ASSERT(groups > 0); // divisor below
+
+    // Subgroup size the kernel is compiled for on each supported GPU family.
+    size_t sgs = 0;
+    if (backend_ctx->gpu_family == ADRENO) {
+        sgs = 64;
+    } else if (backend_ctx->gpu_family == INTEL) {
+        sgs = 32;
+    } else {
+        GGML_ASSERT(false && "Unsupported GPU");
+    }
+
+    cl_kernel kernel = backend_ctx->kernel_group_norm_mul_add;
+    const int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
+    const int ne = (int)ggml_nelements(src0);
+    const int group_size = ne / groups;
+    GGML_ASSERT(group_size > 0); // a zero local size is rejected by the OpenCL runtime
+
+    // Cap the work-group at one subgroup: with several subgroups each of them would derive
+    // mean/variance from only its own share of the group and produce wrong output.
+    const size_t wg_limit = MIN((size_t)max_workgroup_size, (size_t)group_size);
+    size_t lws[] = { MIN(wg_limit, sgs) };
+    size_t gws[] = { (size_t)groups * lws[0] };
+
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extra2->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset2));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &group_size));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float),    &eps));
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 1, gws, lws, dst);
+}
+
static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
diff --git a/ggml/src/ggml-opencl/kernels/group_norm.cl b/ggml/src/ggml-opencl/kernels/group_norm.cl
index 57c9df4d35..8e4fa0ed12 100644
--- a/ggml/src/ggml-opencl/kernels/group_norm.cl
+++ b/ggml/src/ggml-opencl/kernels/group_norm.cl
@@ -70,3 +70,52 @@ kernel void kernel_group_norm(
dst[j] *= scale;
}
}
+
+//------------------------------------------------------------------------------
+// group_norm_mul_add
+//------------------------------------------------------------------------------
+// Fused group norm + scale + shift on flat F32 buffers:
+//     dst[j] = (src0[j] - mean) * rsqrt(var + eps) * src1[j] + src2[j]
+// Work-group g handles elements [g*group_size, min((g+1)*group_size, ne)); mean and var are
+// always divided by group_size. Partial sums are combined with sub_group_reduce_add only,
+// so the launch must keep every work-group inside a single subgroup.
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_32
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_group_norm_mul_add(
+        global float * src0, ulong offset0,
+        global float * src1, ulong offset1,
+        global float * src2, ulong offset2,
+        global float * dst,  ulong offsetd,
+        int ne,
+        int group_size,
+        float eps
+) {
+    // Apply the byte offsets to obtain the actual data pointers.
+    global float * x   = (global float *)((global char *)src0 + offset0);
+    global float * w   = (global float *)((global char *)src1 + offset1);
+    global float * b   = (global float *)((global char *)src2 + offset2);
+    global float * out = (global float *)((global char *)dst  + offsetd);
+
+    // Element range of this group, clipped to the end of the buffer.
+    const int first = get_group_id(0) * group_size;
+    int last = first + group_size;
+    if (last > ne) {
+        last = ne;
+    }
+
+    // Pass 1: each work-item accumulates a strided share of sum(x) and sum(x^2).
+    float acc    = 0.0f;
+    float acc_sq = 0.0f;
+    for (int i = first + get_local_id(0); i < last; i += get_local_size(0)) {
+        const float v = x[i];
+        acc    += v;
+        acc_sq += v*v;
+    }
+
+    acc    = sub_group_reduce_add(acc);
+    acc_sq = sub_group_reduce_add(acc_sq);
+
+    const float mean  = acc / group_size;
+    const float var   = acc_sq / group_size - mean * mean;
+    const float scale = rsqrt(var + eps);
+
+    // Pass 2: normalize, then apply the per-element weight and bias.
+    for (int i = first + get_local_id(0); i < last; i += get_local_size(0)) {
+        out[i] = ((x[i] - mean) * scale) * w[i] + b[i];
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/norm.cl b/ggml/src/ggml-opencl/kernels/norm.cl
index 43167ba4d2..170f822787 100644
--- a/ggml/src/ggml-opencl/kernels/norm.cl
+++ b/ggml/src/ggml-opencl/kernels/norm.cl
@@ -79,3 +79,83 @@ kernel void kernel_norm(
y[i00] = y[i00] * scale;
}
}
+
+//------------------------------------------------------------------------------
+// norm_mul_add
+//------------------------------------------------------------------------------
+// Fused norm + scale + shift over rows of F32 data:
+//     y = (x - mean) * rsqrt(var + eps) * w + b
+// with mean/var taken over the ne00 elements of each row.
+//
+// Launch layout: one work-group per row; group ids (0,1,2) select (i01,i02,i03).
+// Rows of src1 (w) and src2 (b) are picked modulo their own extents, so they broadcast over
+// dims 1..3. Along dim 0 a row is either indexed like the input (ne10/ne20 > 1) or treated
+// as a single value that is splatted across all lanes.
+// Rows are walked as float4 chunks, so only the first (ne00/4)*4 elements are processed.
+// `sums` is local scratch holding one float2 (sum, sum of squares) per subgroup.
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_32
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_norm_mul_add(
+        global char * src0_ptr, ulong src0_offset,
+        global char * src1_ptr, ulong src1_offset,
+        global char * src2_ptr, ulong src2_offset,
+        global char * dst_ptr,  ulong dst_offset,
+        int ne00, int ne01, int ne02, int ne03,
+        ulong nb01, ulong nb02, ulong nb03,
+        int ne10, int ne11, int ne12, int ne13,
+        ulong nb11, ulong nb12, ulong nb13,
+        int ne20, int ne21, int ne22, int ne23,
+        ulong nb21, ulong nb22, ulong nb23,
+        ulong nbd1, ulong nbd2, ulong nbd3,
+        float eps,
+        local float2 * sums
+) {
+    const int i03 = get_group_id(2);
+    const int i02 = get_group_id(1);
+    const int i01 = get_group_id(0);
+
+    // Row base pointers; strides are in bytes.
+    global float4 * x = (global float4 *)(src0_ptr + src0_offset + i01*nb01 + i02*nb02 + i03*nb03);
+    global float4 * w = (global float4 *)(src1_ptr + src1_offset + (i01%ne11)*nb11 + (i02%ne12)*nb12 + (i03%ne13)*nb13);
+    global float4 * b = (global float4 *)(src2_ptr + src2_offset + (i01%ne21)*nb21 + (i02%ne22)*nb22 + (i03%ne23)*nb23);
+    global float4 * y = (global float4 *)(dst_ptr + dst_offset + i01*nbd1 + i02*nbd2 + i03*nbd3);
+
+    float p_sum = 0.0f;
+    float p_sum_sq = 0.0f;
+
+    // Pass 1: per-work-item partial sum and sum of squares over a strided share of the row.
+    const int n_chunks = ne00 / 4;
+    for (int i00 = get_local_id(0); i00 < n_chunks; i00 += get_local_size(0)) {
+        float4 val = x[i00];
+        p_sum += val.x + val.y + val.z + val.w;
+        p_sum_sq += dot(val, val);
+    }
+
+    p_sum = sub_group_reduce_add(p_sum);
+    p_sum_sq = sub_group_reduce_add(p_sum_sq);
+
+    // The first lane of every subgroup publishes that subgroup's totals.
+    if (get_sub_group_local_id() == 0) {
+        sums[get_sub_group_id()] = (float2)(p_sum, p_sum_sq);
+    }
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    // Work-item 0 folds the per-subgroup totals and stores (mean, scale) back into sums[0].
+    if (get_local_id(0) == 0) {
+        float sum = 0.0f;
+        float sum_sq = 0.0f;
+        for (uint i = 0; i < get_num_sub_groups(); ++i) {
+            float2 s = sums[i];
+            sum += s.x;
+            sum_sq += s.y;
+        }
+
+        const float inv_ne00 = 1.0f / (float)ne00;
+        const float mean = sum * inv_ne00;
+        const float variance = mad(-mean, mean, sum_sq * inv_ne00);
+
+        sums[0] = (float2)(mean, rsqrt(variance + eps));
+    }
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    const float2 mean_scale = sums[0];
+    const float mean = mean_scale.x;
+    const float scale = mean_scale.y;
+    // (x - mean) * scale is evaluated as x*scale + (-mean*scale) with a single mad.
+    const float neg_mean_scale = -mean * scale;
+
+    // Pass 2: normalize and apply weight/bias.
+    // Dim-0 broadcast: a single-element w/b row is splatted into all four lanes. Loading it as
+    // a float4 would pull three unrelated floats from beyond that row into lanes y/z/w.
+    for (int i00 = get_local_id(0); i00 < n_chunks; i00 += get_local_size(0)) {
+        const float4 w_val = ne10 > 1 ? w[i00] : (float4)(((global float *)w)[0]);
+        const float4 b_val = ne20 > 1 ? b[i00] : (float4)(((global float *)b)[0]);
+        const float4 norm_x = mad(x[i00], (float4)scale, (float4)neg_mean_scale);
+        y[i00] = mad(norm_x, w_val, b_val);
+    }
+}
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index a0a650e92e..18ff4e0b0c 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -4364,11 +4364,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
#endif
case GGML_OP_NORM:
- case GGML_OP_RMS_NORM:
return true;
case GGML_OP_L2_NORM:
case GGML_OP_GROUP_NORM:
return ggml_is_contiguous(op->src[0]);
+ case GGML_OP_RMS_NORM:
+ return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
case GGML_OP_SCALE:
return true;
case GGML_OP_CONT:
@@ -4391,10 +4392,11 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return true;
case GGML_OP_UPSCALE:
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
- case GGML_OP_POOL_2D:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_ARGSORT:
+ return ggml_is_contiguous(op->src[0]);
+ case GGML_OP_POOL_2D:
case GGML_OP_ACC:
case GGML_OP_PAD:
case GGML_OP_LEAKY_RELU:
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index fb18a55cda..40962de508 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -102,9 +102,9 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
struct ggml_backend_vk_context;
-#define MAX_PARAMETER_COUNT 8
+#define MAX_PARAMETER_COUNT 12
// Max number of adds that can be fused without exceeding MAX_PARAMETER_COUNT.
-#define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 2)
+#define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 3)
struct vk_pipeline_struct {
std::string name;
@@ -115,6 +115,8 @@ struct vk_pipeline_struct {
uint32_t parameter_count;
std::array wg_denoms;
uint32_t align;
+ // true if fields have been set by ggml_vk_create_pipeline
+ bool initialized {};
// set to true to request the pipeline is compiled after the dryrun
bool needed {};
// set to true when the shader has been compiled
@@ -227,21 +229,6 @@ enum vk_device_architecture {
NVIDIA_PRE_TURING,
};
-// HSK x HSV
-enum FaHeadSizes {
- FA_HEAD_SIZE_64,
- FA_HEAD_SIZE_80,
- FA_HEAD_SIZE_96,
- FA_HEAD_SIZE_112,
- FA_HEAD_SIZE_128,
- FA_HEAD_SIZE_192,
- FA_HEAD_SIZE_192_128,
- FA_HEAD_SIZE_256,
- FA_HEAD_SIZE_576_512,
- FA_HEAD_SIZE_UNSUPPORTED,
- FA_HEAD_SIZE_COUNT = FA_HEAD_SIZE_UNSUPPORTED,
-};
-
static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
vk::PhysicalDeviceProperties props = device.getProperties();
@@ -351,6 +338,28 @@ enum dmmv_wg_sizes {
DMMV_WG_SIZE_COUNT,
};
+// Shader code path of a flash-attention pipeline; it is one component of the
+// vk_fa_pipeline_state lookup key.
+// NOTE(review): the names suggest scalar vs. cooperative-matrix (v1/v2) shaders - confirm.
+enum FaCodePath {
+    FA_SCALAR   = 0,
+    FA_COOPMAT1 = 1,
+    FA_COOPMAT2 = 2,
+};
+
+// Identifies one flash-attention pipeline variant. operator< orders the fields
+// lexicographically so the struct can serve as the key of an ordered container.
+struct vk_fa_pipeline_state {
+    vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, FaCodePath path, bool aligned, bool f32acc)
+        : HSK(HSK),
+          HSV(HSV),
+          small_rows(small_rows),
+          path(path),
+          aligned(aligned),
+          f32acc(f32acc) {
+    }
+
+    uint32_t HSK;
+    uint32_t HSV;
+    bool small_rows;
+    FaCodePath path;
+    bool aligned;
+    bool f32acc;
+
+    // Strict weak ordering over (HSK, HSV, small_rows, path, aligned, f32acc).
+    bool operator<(const vk_fa_pipeline_state &b) const {
+        const auto lhs = std::tie(HSK, HSV, small_rows, path, aligned, f32acc);
+        const auto rhs = std::tie(b.HSK, b.HSV, b.small_rows, b.path, b.aligned, b.f32acc);
+        return lhs < rhs;
+    }
+};
+
static constexpr uint32_t num_argsort_pipelines = 11;
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
@@ -379,8 +388,12 @@ struct vk_device_struct {
bool float_controls_rte_fp16;
bool subgroup_add;
bool subgroup_shuffle;
+ bool subgroup_ballot;
bool multi_add;
+ bool add_rms_fusion;
+ uint32_t partials_binding_alignment;
+
bool integer_dot_product;
bool subgroup_size_control;
@@ -460,9 +473,12 @@ struct vk_device_struct {
vk_pipeline pipeline_mul_norepeat[2][2][2];
vk_pipeline pipeline_div[2][2][2];
vk_pipeline pipeline_div_norepeat[2][2][2];
+ vk_pipeline pipeline_add_rms[2][2][2];
+ vk_pipeline pipeline_add_rms_norepeat[2][2][2];
// indexed by num_additional_fused_ops == num_adds - 1
vk_pipeline pipeline_multi_add[MAX_FUSED_ADDS];
+ vk_pipeline pipeline_multi_add_rms[MAX_FUSED_ADDS];
vk_pipeline pipeline_add_id_f32;
@@ -486,6 +502,8 @@ struct vk_device_struct {
vk_pipeline pipeline_group_norm_f32;
vk_pipeline pipeline_rms_norm_f32;
vk_pipeline pipeline_rms_norm_mul_f32;
+ vk_pipeline pipeline_rms_norm_partials_f32;
+ vk_pipeline pipeline_rms_norm_mul_partials_f32;
vk_pipeline pipeline_rms_norm_back_f32;
vk_pipeline pipeline_l2_norm_f32;
@@ -533,16 +551,11 @@ struct vk_device_struct {
vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
- // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
- vk_pipeline pipeline_flash_attn_f32_f16_cm2[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
-
- vk_pipeline pipeline_flash_attn_f32_f16_cm1[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
-
- vk_pipeline pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
+ std::map pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT];
vk_pipeline pipeline_flash_attn_split_k_reduce;
- std::unordered_map pipelines;
+ std::vector all_pipelines;
std::vector> pinned_memory;
@@ -573,15 +586,15 @@ struct vk_device_struct {
compute_queue.cmd_pool.destroy(device);
transfer_queue.cmd_pool.destroy(device);
- for (auto& pipeline : pipelines) {
- if (pipeline.second.expired()) {
+ for (auto& pipeline : all_pipelines) {
+ if (pipeline.expired()) {
continue;
}
- vk_pipeline pl = pipeline.second.lock();
+ vk_pipeline pl = pipeline.lock();
ggml_vk_destroy_pipeline(device, pl);
}
- pipelines.clear();
+ all_pipelines.clear();
device.destroyDescriptorSetLayout(dsl);
@@ -823,8 +836,13 @@ struct vk_op_multi_add_push_constants {
uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23;
// strides for srcs+dst
- uint32_t nb[8][4];
+ uint32_t nb[MAX_PARAMETER_COUNT][4];
+
+ uint32_t rms_partials;
};
+// update multi_add.comp if this changes
+static_assert(MAX_PARAMETER_COUNT == 12);
+static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
struct vk_op_add_id_push_constants {
uint32_t ne0;
@@ -1015,6 +1033,39 @@ struct vk_op_upscale_push_constants {
float sf0; float sf1; float sf2; float sf3;
};
+// Push constants for the sum_rows shader. Field order and types define the binary layout
+// shared with the shader, so they must not be reordered.
+struct vk_op_sum_rows_push_constants {
+    uint32_t n_cols;
+    // source extents along dims 1 and 2
+    uint32_t ne01;
+    uint32_t ne02;
+    // source strides
+    uint32_t nb01;
+    uint32_t nb02;
+    uint32_t nb03;
+    // destination strides
+    uint32_t nb11;
+    uint32_t nb12;
+    uint32_t nb13;
+    float weight;
+    uint32_t misalign_offsets;
+    // fastdiv values for dividing by ne01*ne02
+    uint32_t ne0_12mp;
+    uint32_t ne0_12L;
+    // fastdiv values for dividing by ne01
+    uint32_t ne0_1mp;
+    uint32_t ne0_1L;
+};
+
+// Builds the sum_rows push constants for reducing `src` into `dst`.
+//
+// src    : source tensor; supplies ne01/ne02 and the source strides.
+// dst    : destination tensor; supplies the destination strides.
+// n_cols : number of columns to reduce per row.
+//
+// Strides are converted from bytes to elements. The division is done at full width and only
+// the resulting element stride is narrowed to 32 bits, so a byte stride above 4 GiB is not
+// truncated before it is divided. weight starts at 1.0f; every other field is zero.
+// NOTE(review): dst strides are divided by the element size of src, which assumes both
+// tensors share an element size - confirm.
+static vk_op_sum_rows_push_constants vk_op_sum_rows_push_constants_init(const ggml_tensor * src, const ggml_tensor * dst, int64_t n_cols) {
+    const size_t type_size = ggml_type_size(src->type);
+    vk_op_sum_rows_push_constants p = {};
+    p.n_cols = (uint32_t)n_cols;
+    p.ne01 = (uint32_t)src->ne[1];
+    p.ne02 = (uint32_t)src->ne[2];
+    p.nb01 = (uint32_t)(src->nb[1] / type_size);
+    p.nb02 = (uint32_t)(src->nb[2] / type_size);
+    p.nb03 = (uint32_t)(src->nb[3] / type_size);
+    p.nb11 = (uint32_t)(dst->nb[1] / type_size);
+    p.nb12 = (uint32_t)(dst->nb[2] / type_size);
+    p.nb13 = (uint32_t)(dst->nb[3] / type_size);
+    p.weight = 1.0f;
+    return p;
+}
+
+// Fills the two fastdiv (mp, L) pairs of the sum_rows push constants via
+// init_fastdiv_values: one for the divisor ne01*ne02 and one for the divisor ne01.
+template <> void init_pushconst_fastdiv(vk_op_sum_rows_push_constants &p) {
+    const uint32_t div_ne01_ne02 = p.ne01*p.ne02;
+    const uint32_t div_ne01      = p.ne01;
+
+    init_fastdiv_values(div_ne01_ne02, p.ne0_12mp, p.ne0_12L);
+    init_fastdiv_values(div_ne01,      p.ne0_1mp,  p.ne0_1L);
+}
+
// Allow pre-recording command buffers
struct vk_staging_memcpy {
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -1175,6 +1226,12 @@ class vk_perf_logger {
timings[name].push_back(time);
return;
}
+ if (node->op == GGML_OP_RMS_NORM) {
+ std::string name = ggml_op_name(node->op);
+ name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")";
+ timings[name].push_back(time);
+ return;
+ }
timings[ggml_op_name(node->op)].push_back(time);
}
private:
@@ -1189,15 +1246,26 @@ struct ggml_backend_vk_context {
size_t semaphore_idx, event_idx;
ggml_vk_garbage_collector gc;
- size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k;
- vk_buffer prealloc_x, prealloc_y, prealloc_split_k;
+ size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_add_rms_partials, prealloc_size_add_rms_partials_offset;
+ vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_add_rms_partials;
vk::Fence fence, almost_ready_fence;
bool almost_ready_fence_pending {};
+ // Set before op_add and unset after op_rms_norm to indicate that the add should
+ // write partial sums to accumulate the square of the vector components
+ bool do_add_rms_partials;
// Cache most recent tensor that was converted into prealloc_y, and what pipeline it used to convert.
vk_pipeline_struct * prealloc_y_last_pipeline_used {};
const ggml_tensor * prealloc_y_last_tensor_used {};
+ // Track which nodes have been used since the last sync, and whether they were written to
+ std::vector unsynced_nodes_written;
+ std::vector unsynced_nodes_read;
+ // Track which prealloc buffers have pending reads that need to be synchronized.
+ // These are checked before writing to the buffer (and call ggml_vk_sync_buffers if set),
+ // and set to true after the buffer contents are consumed.
+ bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync;
+
vk_buffer buffer_pool[MAX_VK_BUFFERS];
vk_context_ref compute_ctx;
@@ -1436,7 +1504,7 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
{
std::lock_guard guard(device->mutex);
- device->pipelines.insert({ pipeline->name, pipeline });
+ device->all_pipelines.push_back(pipeline);
}
{
@@ -1873,14 +1941,18 @@ static vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) {
return { buf, 0, VK_WHOLE_SIZE };
}
-static void ggml_vk_sync_buffers(vk_context& ctx) {
+static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subctx) {
VK_LOG_DEBUG("ggml_vk_sync_buffers()");
- const bool transfer_queue = ctx->p->q->transfer_only;
+ const bool transfer_queue = subctx->p->q->transfer_only;
- ctx->s->buffer.pipelineBarrier(
- ctx->p->q->stage_flags,
- ctx->p->q->stage_flags,
+ if (ctx) {
+ ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false;
+ }
+
+ subctx->s->buffer.pipelineBarrier(
+ subctx->p->q->stage_flags,
+ subctx->p->q->stage_flags,
{},
{ {
{ !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) },
@@ -1907,47 +1979,12 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events
);
}
-enum FaCodePath {
- FA_SCALAR,
- FA_COOPMAT1,
- FA_COOPMAT2,
-};
-
-static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) {
- if (hsk != 192 && hsk != 576 && hsk != hsv) {
- return FA_HEAD_SIZE_UNSUPPORTED;
- }
- switch (hsk) {
- case 64: return FA_HEAD_SIZE_64;
- case 80: return FA_HEAD_SIZE_80;
- case 96: return FA_HEAD_SIZE_96;
- case 112: return FA_HEAD_SIZE_112;
- case 128: return FA_HEAD_SIZE_128;
- case 192:
- if (hsv == 192) {
- return FA_HEAD_SIZE_192;
- } else if (hsv == 128) {
- return FA_HEAD_SIZE_192_128;
- } else {
- return FA_HEAD_SIZE_UNSUPPORTED;
- }
- case 256: return FA_HEAD_SIZE_256;
- case 576:
- if (hsv == 512) {
- return FA_HEAD_SIZE_576_512;
- } else {
- return FA_HEAD_SIZE_UNSUPPORTED;
- }
- default: return FA_HEAD_SIZE_UNSUPPORTED;
- }
-}
-
// number of rows/cols for flash attention shader
static constexpr uint32_t flash_attention_num_small_rows = 32;
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) {
- if (hsv >= 512) {
+ if (hsv >= 192) {
return 2;
} else {
return 8;
@@ -1977,7 +2014,13 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
if (small_rows) {
return {scalar_flash_attention_num_small_rows, 64};
} else {
- return {get_fa_scalar_num_large_rows(hsv), 32};
+ if ((hsv | hsk) & 8) {
+ // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
+ // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
+ return {get_fa_scalar_num_large_rows(hsv), 64};
+ } else {
+ return {get_fa_scalar_num_large_rows(hsv), 32};
+ }
}
}
@@ -1995,8 +2038,8 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
}
// small cols to reduce register count
- if (ggml_is_quantized(type) || hsk >= 256) {
- if (hsk >= 512) {
+ if (ggml_is_quantized(type) || hsk >= 256 || hsv >= 256) {
+ if (hsk >= 512 || hsv >= 512) {
return {32, 32};
} else {
return {64, 32};
@@ -2005,6 +2048,10 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
return {64, 64};
}
+// Alignment used for a flash-attention pipeline: the column count (element [1]) that
+// fa_rows_cols() reports for this configuration with its fourth argument fixed to 0.
+static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows) {
+    const auto rows_cols = fa_rows_cols(path, hsk, hsv, 0, type, small_rows);
+    return rows_cols[1];
+}
+
static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) {
uint32_t lut_size = 0;
@@ -2043,10 +2090,11 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
const uint32_t warps = warptile[0] / warptile[10];
const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
- const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0;
+ const uint32_t mmid_row_ids = mul_mat_id ? (warptile[2] * 2 * sizeof(uint16_t)) : 0;
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
+ const uint32_t ballots_sh = mul_mat_id ? (warps * 4 * sizeof(uint32_t)) : 0;
- const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
+ const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size + ballots_sh;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), "
@@ -2130,8 +2178,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u);
+ const uint32_t mul_mat_subgroup_size = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
+ const uint32_t mul_mat_subgroup_size_8 = std::max(mul_mat_subgroup_size, 8u);
+ const uint32_t mul_mat_subgroup_size_16 = std::max(mul_mat_subgroup_size, 16u);
+ const uint32_t mul_mat_subgroup_size_32 = std::max(mul_mat_subgroup_size, 32u);
+
+ const bool subgroup_min_size_16 = (!device->subgroup_size_control && device->subgroup_size >= 16) ||
+ (device->subgroup_size_control && device->subgroup_max_size >= 16);
+
// mulmat
std::vector l_warptile, m_warptile, s_warptile,
+ l_warptile_id, m_warptile_id, s_warptile_id,
l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
@@ -2168,9 +2225,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
s_mmq_wg_denoms_k = { 32, 64, 1 };
// spec constants and tile sizes for quant matmul_id
- l_warptile_mmqid = { 256, 128, 128, 16, 0 };
- m_warptile_mmqid = { 256, 128, 64, 16, 0 };
- s_warptile_mmqid = { 256, 128, 64, 16, 0 };
+ l_warptile_mmqid = { 256, 128, 128, 16, 0, device->subgroup_size };
+ m_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
+ s_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
l_mmqid_wg_denoms = { 128, 128, 1 };
m_mmqid_wg_denoms = { 128, 64, 1 };
s_mmqid_wg_denoms = { 128, 64, 1 };
@@ -2202,9 +2259,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
+ l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
+ m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
+ s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };
+
+ l_warptile_mmqid = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 };
+ m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
+ s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
+
// chip specific tuning
if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
+ m_warptile_mmqid = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
}
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
@@ -2230,14 +2296,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
// Disable mul_mat_id if not enough shared memory is available
- if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true, t)) {
+ if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmqid, true, t)) {
device->mul_mat_id_s[i] = false;
device->mul_mat_id_m[i] = false;
device->mul_mat_id_l[i] = false;
- } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true, t)) {
+ } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmqid, true, t)) {
device->mul_mat_id_m[i] = false;
device->mul_mat_id_l[i] = false;
- } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true, t)) {
+ } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) {
device->mul_mat_id_l[i] = false;
}
}
@@ -2270,11 +2336,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (!pipeline) {
pipeline = std::make_shared();
+ }
+ if (!pipeline->initialized) {
pipeline->name = name;
pipeline->parameter_count = parameter_count;
pipeline->push_constant_size = push_constant_size;
pipeline->wg_denoms = wg_denoms;
pipeline->align = align;
+ pipeline->initialized = true;
}
if (!pipeline->needed || pipeline->compiled) {
@@ -2320,26 +2389,30 @@ static void ggml_vk_load_shaders(vk_device& device) {
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
};
-#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64, 64, 64) \
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80, 80, 80) \
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96, 96, 96) \
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112, 112, 112) \
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128, 128, 128) \
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 192, 192) \
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 128, 192_128) \
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256, 256, 256) \
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 576, 512, 576_512)
+ for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \
+ uint32_t HSK = fa.first.HSK; \
+ uint32_t HSV = fa.first.HSV; \
+ bool small_rows = fa.first.small_rows; \
+ FaCodePath path = fa.first.path; \
+ bool aligned = fa.first.aligned; \
+ bool f32acc = fa.first.f32acc; \
+ if (path == FAPATH) { \
+ if (aligned) { \
+ if (f32acc) { \
+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+ } else { \
+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+ } \
+ } else { \
+ if (f32acc) { \
+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+ } else { \
+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+ } \
+ } \
+ } \
+ }
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
@@ -2362,7 +2435,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2)
}
#endif
-#undef CREATE_FA2
#undef CREATE_FA
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
@@ -2409,32 +2481,34 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
- CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
+ GGML_ASSERT(device->subgroup_ballot);
+
+ CREATE_MM2(pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
if (device->coopmat_bf16_support) {
- CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
}
#endif
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
#undef CREATE_MM
#undef CREATE_MM2
} else
@@ -2521,55 +2595,56 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
}
- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+ GGML_ASSERT(device->subgroup_ballot);
+
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
if (device->coopmat_bf16_support) {
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
}
#endif
- CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-
- CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
#undef CREATE_MM2
#undef CREATE_MM
} else
#endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (device->fp16) {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
-#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _l[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l[TYPE]) { \
@@ -2586,38 +2661,38 @@ static void ggml_vk_load_shaders(vk_device& device) {
} \
// Create 2 variants, {f16,f32} accumulator
-#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
- CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
- CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
+ CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
+ CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
- CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
- CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
@@ -2629,51 +2704,77 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
#endif
- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+ if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ } else {
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-
- CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ }
#undef CREATE_MM2
#undef CREATE_MMQ
#undef CREATE_MM
} else {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
-#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) /* REQSUBGROUPSIZE > 0 turns on the subgroup flag for all six variants; 0 leaves it off */ \
if (device->mul_mat ## ID ## _l[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _l[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l[TYPE]) \
@@ -2683,34 +2784,34 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
- CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
- CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
@@ -2722,33 +2823,59 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
#endif
- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+ if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_subgroup_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_subgroup_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_subgroup_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_subgroup_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_subgroup_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_subgroup_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_subgroup_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_subgroup_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_subgroup_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_subgroup_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_subgroup_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_subgroup_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_subgroup_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ } else {
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-
- CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ }
}
// reusing CREATE_MM from the fp32 path
if ((device->coopmat2 || device->coopmat_support)
@@ -2765,8 +2892,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
m_wg_denoms = { 64, 64, 1 };
s_wg_denoms = { 32, 32, 1 };
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
}
#undef CREATE_MM
@@ -2942,8 +3069,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1);
+
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
+
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
@@ -3013,25 +3144,28 @@ static void ggml_vk_load_shaders(vk_device& device) {
};
bool rte = device->float_controls_rte_fp16;
-#define CREATE_BINARY(name, namemod, spec) \
+#define CREATE_BINARY(name, namemod, spec, bindings) \
for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
#name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
- "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
+ "main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
- CREATE_BINARY(add, , {0})
- CREATE_BINARY(add, _norepeat, {1})
- CREATE_BINARY(sub, , {0})
- CREATE_BINARY(sub, _norepeat, {1})
- CREATE_BINARY(mul, , {0})
- CREATE_BINARY(mul, _norepeat, {1})
- CREATE_BINARY(div, , {0})
- CREATE_BINARY(div, _norepeat, {1})
+ CREATE_BINARY(add, , {0}, 4)
+ CREATE_BINARY(add, _norepeat, {1}, 4)
+ CREATE_BINARY(sub, , {0}, 3)
+ CREATE_BINARY(sub, _norepeat, {1}, 3)
+ CREATE_BINARY(mul, , {0}, 3)
+ CREATE_BINARY(mul, _norepeat, {1}, 3)
+ CREATE_BINARY(div, , {0}, 3)
+ CREATE_BINARY(div, _norepeat, {1}, 3)
+ CREATE_BINARY(add_rms, , {0}, 4)
+ CREATE_BINARY(add_rms, _norepeat, {1}, 4)
#undef CREATE_BINARY
if (device->multi_add) {
for (uint32_t i = 0; i < MAX_FUSED_ADDS; ++i) {
- ggml_vk_create_pipeline(device, device->pipeline_multi_add[i], "multi_add_f32_" + std::to_string(i+1), multi_add_f32_len, multi_add_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_multi_add[i], "multi_add_f32_" + std::to_string(i+1), multi_add_f32_len, multi_add_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_multi_add_rms[i], "multi_add_rms_f32_" + std::to_string(i+1), multi_add_rms_f32_len, multi_add_rms_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
}
}
@@ -3128,7 +3262,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
@@ -3447,6 +3581,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle);
+ device->subgroup_ballot = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
+ (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBallot);
+
const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
@@ -3596,9 +3733,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
(subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) &&
subgroup_size_control_features.subgroupSizeControl;
- if (device->subgroup_size_control) {
- device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups;
- }
+ device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups;
#if defined(VK_KHR_cooperative_matrix)
device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
@@ -3899,6 +4034,12 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
+ device->add_rms_fusion = !device->disable_fusion &&
+ device->subgroup_add &&
+ device->vendor_id != VK_VENDOR_ID_INTEL;
+ device->partials_binding_alignment =
+ std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment);
+
return device;
}
@@ -4865,7 +5006,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
}
}
- ggml_vk_sync_buffers(subctx);
+ ggml_vk_sync_buffers(ctx, subctx);
subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
return;
}
@@ -4880,7 +5021,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
ggml_vk_ensure_sync_staging_buffer(ctx->device, copy_size);
VkBufferCopy buf_copy{ 0, offset, copy_size };
- ggml_vk_sync_buffers(subctx);
+ ggml_vk_sync_buffers(ctx, subctx);
vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
for (uint64_t i3 = 0; i3 < ne3; i3++) {
@@ -4934,7 +5075,7 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
}
}
- ggml_vk_sync_buffers(subctx);
+ ggml_vk_sync_buffers(nullptr, subctx);
subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
return;
}
@@ -4955,7 +5096,7 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
offset,
copy_size};
- ggml_vk_sync_buffers(subctx);
+ ggml_vk_sync_buffers(nullptr, subctx);
vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
if (width == spitch) {
@@ -5035,7 +5176,7 @@ static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size
if (buf != nullptr) {
// Memory is pinned, use as staging buffer
- ggml_vk_sync_buffers(subctx);
+ ggml_vk_sync_buffers(nullptr, subctx);
subctx->s->buffer.copyBuffer(src->buffer, buf->buffer, slices);
return;
@@ -5052,7 +5193,7 @@ static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size
vk_buffer& staging_buffer = src->device->sync_staging;
- ggml_vk_sync_buffers(subctx);
+ ggml_vk_sync_buffers(nullptr, subctx);
subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices);
deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys);
@@ -5242,13 +5383,16 @@ static void ggml_vk_matmul(
uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
uint32_t padded_n) {
VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
- ggml_vk_sync_buffers(subctx);
if (split_k == 1) {
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch });
return;
}
+ if (ctx->prealloc_split_k_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
+
GGML_ASSERT(batch_stride_d == m * n);
// Round the split size up to a multiple of 256 (k-quant alignment)
@@ -5258,9 +5402,10 @@ static void ggml_vk_matmul(
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
// Make sure enough workgroups get assigned for split k to work
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
- ggml_vk_sync_buffers(subctx);
+ ggml_vk_sync_buffers(ctx, subctx);
const std::array pc2 = { (uint32_t)(m * n * batch), split_k };
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 });
+ ctx->prealloc_split_k_need_sync = true;
}
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {
@@ -5305,7 +5450,6 @@ static void ggml_vk_matmul_id(
"m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " <<
"batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
"n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")");
- ggml_vk_sync_buffers(subctx);
const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
nei0, nei1, nbi1, ne11, padded_n };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, pc, { m, nei1, n_as });
@@ -5436,8 +5580,8 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
};
init_pushconst_fastdiv(pc);
- ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements);
+ ggml_vk_sync_buffers(ctx, subctx);
}
static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
@@ -5455,8 +5599,8 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub
vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
- ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array{ne}, { ne, 1, 1 });
+ ggml_vk_sync_buffers(ctx, subctx);
}
static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -5651,16 +5795,25 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
GGML_ASSERT(qy_sz == y_sz);
}
+ if (x_non_contig || qx_needs_dequant) {
+ if (ctx->prealloc_x_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
+ }
+
if (x_non_contig) {
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
} else if (qx_needs_dequant) {
const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
- ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
+ ggml_vk_sync_buffers(ctx, subctx);
}
if (y_non_contig) {
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
+ if (ctx->prealloc_y_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
@@ -5669,6 +5822,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
if (quantize_y) {
if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
+ if (ctx->prealloc_y_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13);
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
ctx->prealloc_y_last_tensor_used = src1;
@@ -5695,6 +5851,13 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n
); // NOLINT
+
+ if (x_non_contig || qx_needs_dequant) {
+ ctx->prealloc_x_need_sync = true;
+ }
+ if (y_non_contig || quantize_y) {
+ ctx->prealloc_y_need_sync = true;
+ }
}
static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -5841,6 +6004,12 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
GGML_ASSERT(qy_sz == y_sz);
}
+ if (x_non_contig) {
+ if (ctx->prealloc_x_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
+ }
+
if (x_non_contig) {
GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
@@ -5849,6 +6018,9 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
+ if (ctx->prealloc_y_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
@@ -5884,10 +6056,16 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
stride_batch_x, stride_batch_y, stride_batch_d,
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
};
- ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
{ vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} },
pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
+
+ if (x_non_contig) {
+ ctx->prealloc_x_need_sync = true;
+ }
+ if (y_non_contig) {
+ ctx->prealloc_y_need_sync = true;
+ }
}
static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -5974,7 +6152,6 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
workgroups_z /= gqa_ratio;
}
- ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { 1, (uint32_t)ne01, workgroups_z });
}
@@ -6061,7 +6238,6 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
// compute
const std::array pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), nb03, nb13, nb23 };
- ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
{ vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 });
}
@@ -6112,7 +6288,6 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
const uint64_t nei0 = ids->ne[0];
const uint64_t nei1 = ids->ne[1];
- GGML_ASSERT(nei0 * nei1 <= 4096);
const uint32_t nbi1 = ids->nb[1];
const uint32_t nbi2 = ids->nb[2];
@@ -6273,17 +6448,26 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
GGML_ASSERT(qy_sz == y_sz);
}
+ if (x_non_contig || qx_needs_dequant) {
+ if (ctx->prealloc_x_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
+ }
+
if (x_non_contig) {
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
} else if (qx_needs_dequant) {
const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
- ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0,
{ vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
+ ggml_vk_sync_buffers(ctx, subctx);
}
if (y_non_contig) {
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
+ if (ctx->prealloc_y_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
@@ -6310,6 +6494,13 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
stride_batch_x, stride_batch_y, ne20*ne21,
n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n
); // NOLINT
+
+ if (x_non_contig || qx_needs_dequant) {
+ ctx->prealloc_x_need_sync = true;
+ }
+ if (y_non_contig) {
+ ctx->prealloc_y_need_sync = true;
+ }
}
static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) {
@@ -6469,6 +6660,12 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
GGML_ASSERT(qy_sz == y_sz);
}
+ if (x_non_contig) {
+ if (ctx->prealloc_x_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
+ }
+
if (x_non_contig) {
GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
@@ -6477,6 +6674,9 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
+ if (ctx->prealloc_y_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
@@ -6505,11 +6705,17 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
(uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21),
(uint32_t)nei0, (uint32_t)ne11,
};
- ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
{ vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 },
vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23}, vk_subbuffer{ d_ids, ids_buf_offset, ids_sz } },
pc, { groups_x, (uint32_t)nei0, groups_z });
+
+ if (x_non_contig) {
+ ctx->prealloc_x_need_sync = true;
+ }
+ if (y_non_contig) {
+ ctx->prealloc_y_need_sync = true;
+ }
}
static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
@@ -6517,37 +6723,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
} else {
- // Split based on number of ids, to fit in shared memory
- const uint32_t nei0 = (uint32_t)src2->ne[0];
- const uint32_t nei1 = (uint32_t)src2->ne[1];
-
- GGML_ASSERT(nei0 <= 4096);
- const uint32_t split_size = std::min(nei1, 4096u / nei0);
-
- if (split_size == nei1) {
- ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
- } else {
- ggml_tensor src1_copy = *src1;
- ggml_tensor src2_copy = *src2;
- ggml_tensor dst_copy = *dst;
-
- for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) {
- const uint32_t n_tokens = std::min(split_size, nei1 - token_start);
-
- src1_copy.view_offs = src1->view_offs + token_start * src1_copy.nb[2];
- src2_copy.view_offs = src2->view_offs + token_start * src2_copy.nb[1];
- dst_copy.view_offs = dst->view_offs + token_start * dst_copy.nb[2];
-
- src1_copy.ne[2] = n_tokens;
- src2_copy.ne[1] = n_tokens;
- dst_copy.ne[2] = n_tokens;
-
- ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, &src1_copy, &src2_copy, &dst_copy, dryrun);
- // invalidate cached prealloc_y, can't cache based on the copy of the ggml_tensor
- ctx->prealloc_y_last_pipeline_used = {};
- ctx->prealloc_y_last_tensor_used = nullptr;
- }
- }
+ ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
}
}
@@ -6580,18 +6756,21 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
const uint32_t Br = coopmat1_flash_attention_num_large_rows;
const uint32_t Bc = scalar_flash_attention_Bc;
+ const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16);
+
const uint32_t acctype = f32acc ? 4 : 2;
const uint32_t f16vec4 = 8;
const uint32_t tmpsh = wg_size * sizeof(float);
const uint32_t tmpshv4 = wg_size * 4 * acctype;
- const uint32_t Qf = Br * (hsk / 4 + 2) * f16vec4;
+ const uint32_t qstride = hsk_pad / 4 + 2;
+ const uint32_t Qf = Br * qstride * f16vec4;
const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
const uint32_t sfsh = Bc * sfshstride * acctype;
- const uint32_t kshstride = hsk / 4 + 2;
+ const uint32_t kshstride = hsk_pad / 4 + 2;
const uint32_t ksh = Bc * kshstride * f16vec4;
const uint32_t slope = Br * sizeof(float);
@@ -6702,7 +6881,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
workgroups_y /= N;
}
- vk_pipeline *pipelines;
bool small_rows = N <= get_fa_num_small_rows(path);
// coopmat1 does not actually support "small rows" (it needs 16 rows).
@@ -6722,37 +6900,36 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
small_rows = true;
}
- bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
-
- FaHeadSizes head_sizes = fa_get_head_sizes(k->ne[0], v->ne[0]);
-
- switch (path) {
- case FA_SCALAR:
- pipelines = &ctx->device->pipeline_flash_attn_f32_f16[k->type][head_sizes][f32acc][small_rows][0];
- break;
- case FA_COOPMAT1:
- pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm1[k->type][head_sizes][f32acc][small_rows][0];
- break;
- case FA_COOPMAT2:
- pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm2[k->type][head_sizes][f32acc][small_rows][0];
- break;
- default:
- GGML_ASSERT(0);
- }
- assert(pipelines);
-
const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
- bool aligned = (KV % pipelines[1]->align) == 0 &&
+ uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows);
+ bool aligned = (KV % alignment) == 0 &&
// the "aligned" shader variant will forcibly align strides, for performance
(q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
+ // Need to use the coopmat2 variant that clamps loads when HSK/HSV aren't sufficiently aligned.
+ if (((HSK | HSV) % 16) != 0 && path == FA_COOPMAT2) {
+ aligned = false;
+ }
// mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0);
- vk_pipeline pipeline = pipelines[aligned];
+ bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
+
+ vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, path, aligned, f32acc);
+
+ vk_pipeline pipeline = nullptr;
+
+ auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type];
+ auto it = pipelines.find(fa_pipeline_state);
+ if (it != pipelines.end()) {
+ pipeline = it->second;
+ } else {
+ pipelines[fa_pipeline_state] = pipeline = std::make_shared();
+ }
+
assert(pipeline);
uint32_t split_kv = KV;
@@ -6768,7 +6945,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
if (split_k > 1) {
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
// of "align", so recompute split_k based on that.
- split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), pipelines[1]->align);
+ split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
split_k = CEIL_DIV(KV, split_kv);
workgroups_x = split_k;
}
@@ -6892,9 +7069,11 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
mask_n_head_log2, m0, m1,
gqa_ratio, split_kv, split_k };
- ggml_vk_sync_buffers(subctx);
-
if (split_k > 1) {
+ if (ctx->prealloc_split_k_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
@@ -6910,7 +7089,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
// cancel out the divide by wg_denoms[0].
pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
- ggml_vk_sync_buffers(subctx);
+ ggml_vk_sync_buffers(ctx, subctx);
const std::array pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) };
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
{
@@ -6919,6 +7098,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
},
pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
+ ctx->prealloc_split_k_need_sync = true;
} else {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
@@ -6961,7 +7141,7 @@ static std::array ggml_vk_get_conv_elements(const ggml_tensor *dst)
return elements;
}
-static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
+static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * dst, ggml_op op) {
switch (op) {
case GGML_OP_GET_ROWS:
GGML_ASSERT(src1->type == GGML_TYPE_I32);
@@ -6990,10 +7170,19 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
case GGML_OP_ADD:
{
if (ctx->num_additional_fused_ops > 0) {
- return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
+ if (ctx->do_add_rms_partials) {
+ return ctx->device->pipeline_multi_add_rms[ctx->num_additional_fused_ops];
+ } else {
+ return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
+ }
+ }
+ if (ctx->do_add_rms_partials) {
+ auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_rms_norepeat : ctx->device->pipeline_add_rms;
+ return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
+ } else {
+ auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
+ return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
}
- auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
- return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
}
case GGML_OP_SUB:
{
@@ -7116,7 +7305,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return nullptr;
case GGML_OP_RMS_NORM:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
+ if (ctx->do_add_rms_partials) {
+ return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_partials_f32 : ctx->device->pipeline_rms_norm_partials_f32;
+ } else {
+ return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
+ }
}
return nullptr;
case GGML_OP_RMS_NORM_BACK:
@@ -7249,6 +7442,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return nullptr;
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_sum_rows_f32;
}
@@ -7387,6 +7581,9 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
case GGML_OP_CONV_2D_DW:
case GGML_OP_IM2COL:
case GGML_OP_SET_ROWS:
+ case GGML_OP_SUM:
+ case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
return true;
default:
return false;
@@ -7421,6 +7618,16 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
GGML_UNUSED(src2);
}
+template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_sum_rows_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
+ const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
+ const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
+
+ p.misalign_offsets = (a_offset << 16) | d_offset;
+
+ GGML_UNUSED(src1);
+ GGML_UNUSED(src2);
+}
+
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
@@ -7571,10 +7778,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
if (op_supports_incontiguous) {
- x_sz = ggml_nbytes(src0);
- y_sz = use_src1 ? ggml_nbytes(src1) : 0;
- z_sz = use_src2 ? ggml_nbytes(src2) : 0;
- d_sz = ggml_nbytes(dst);
+ x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0);
+ y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0;
+ z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0;
+ d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst);
if (x_buf_offset + x_sz >= d_X->size) {
x_sz = VK_WHOLE_SIZE;
@@ -7602,6 +7809,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
case GGML_OP_ARGMAX:
{
const uint32_t nr = ggml_nrows(src0);
@@ -7614,7 +7822,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
}
} break;
case GGML_OP_RMS_NORM:
- elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
+ if (ctx->do_add_rms_partials) {
+ // Run one element per thread, 128 threads per workgroup
+ elements = { (uint32_t)CEIL_DIV(ne00, 128), 1, 1 };
+ } else {
+ elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
+ }
break;
case GGML_OP_SUM:
@@ -7763,7 +7976,16 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
}
}
- if (op == GGML_OP_GLU) {
+ if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) {
+ vk_buffer d_A = ctx->do_add_rms_partials ? ctx->prealloc_add_rms_partials : d_X;
+ size_t a_buf_offset = ctx->do_add_rms_partials ? ctx->prealloc_size_add_rms_partials_offset : 0;
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+ { vk_subbuffer{ d_X, x_buf_offset, x_sz },
+ vk_subbuffer{ d_Y, y_buf_offset, y_sz },
+ vk_subbuffer{ d_D, d_buf_offset, d_sz },
+ vk_subbuffer{ d_A, a_buf_offset, VK_WHOLE_SIZE },
+ }, pc, elements);
+ } else if (op == GGML_OP_GLU) {
// Empty src1 is possible in glu, but the shader needs a buffer
vk_subbuffer subbuf_y;
if (use_src1) {
@@ -7772,7 +7994,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
subbuf_y = { d_X, 0, x_sz };
}
- ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_SOFT_MAX) {
// Empty src1 and src2 is possible in soft_max, but the shader needs a buffer
@@ -7790,7 +8011,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
subbuf_z = { d_X, 0, x_sz };
}
- ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
// Empty src2 is possible in rope, but the shader needs a buffer
@@ -7801,30 +8021,23 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
subbuf_z = { d_X, 0, x_sz };
}
- ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_IM2COL) {
// im2col uses only src1 and dst buffers
- ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_COUNT_EQUAL) {
- ggml_vk_sync_buffers(subctx);
// count_equal assumes that destination buffer is initialized with zeroes
ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
- ggml_vk_sync_buffers(subctx);
+ ggml_vk_sync_buffers(ctx, subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_OPT_STEP_SGD) {
// OPT_STEP_SGD works on src0, it does not need dst
- ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements);
} else if (use_src2) {
- ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (use_src1) {
- ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else {
- ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
}
}
@@ -7873,7 +8086,7 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx,
const ggml_tensor *tensors[MAX_PARAMETER_COUNT];
uint32_t num_srcs = ctx->num_additional_fused_ops + 2;
uint32_t num_tensors = num_srcs + 1;
- GGML_ASSERT(num_tensors <= MAX_PARAMETER_COUNT);
+ GGML_ASSERT(num_tensors + ctx->do_add_rms_partials <= MAX_PARAMETER_COUNT);
tensors[0] = first_node->src[0];
tensors[1] = first_node->src[1];
@@ -7900,8 +8113,9 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx,
pc.nb[i][2] = (uint32_t)t->nb[2] / sizeof(float);
pc.nb[i][3] = (uint32_t)t->nb[3] / sizeof(float);
}
+ pc.rms_partials = ctx->do_add_rms_partials;
- vk_pipeline pipeline = ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
+ vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, tensors[0], tensors[1], nullptr, dst, dst->op);
if (pipeline == nullptr) {
std::cerr << "ggml_vulkan: Error: Missing multi_add";
@@ -7939,6 +8153,10 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx,
buf[i] = buf[0];
offset[i] = 0;
}
+ if (ctx->do_add_rms_partials) {
+ buf[num_tensors] = ctx->prealloc_add_rms_partials;
+ offset[num_tensors] = ctx->prealloc_size_add_rms_partials_offset;
+ }
std::array elements;
@@ -7951,7 +8169,7 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx,
elements = { ne, 1, 1 };
}
- ggml_vk_sync_buffers(subctx);
+ static_assert(MAX_PARAMETER_COUNT == 12);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
vk_subbuffer{ buf[0], offset[0], VK_WHOLE_SIZE },
@@ -7962,6 +8180,10 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx,
vk_subbuffer{ buf[5], offset[5], VK_WHOLE_SIZE },
vk_subbuffer{ buf[6], offset[6], VK_WHOLE_SIZE },
vk_subbuffer{ buf[7], offset[7], VK_WHOLE_SIZE },
+ vk_subbuffer{ buf[8], offset[8], VK_WHOLE_SIZE },
+ vk_subbuffer{ buf[9], offset[9], VK_WHOLE_SIZE },
+ vk_subbuffer{ buf[10], offset[10], VK_WHOLE_SIZE },
+ vk_subbuffer{ buf[11], offset[11], VK_WHOLE_SIZE },
}, pc, elements);
}
@@ -7976,7 +8198,7 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
- 0.0f, 0.0f, 0,
+ 0.0f, 0.0f, ctx->do_add_rms_partials,
}, dryrun);
}
@@ -8064,8 +8286,6 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx
src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
}
- ggml_vk_sync_buffers(subctx);
-
vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 };
bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false };
@@ -8203,8 +8423,6 @@ static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_cont
ggml_backend_vk_buffer_context * gv_buf_ctx = (ggml_backend_vk_buffer_context *)gv->buffer->context;
ggml_backend_vk_buffer_context * p_buf_ctx = (ggml_backend_vk_buffer_context *)p->buffer->context;
- ggml_vk_sync_buffers(subctx);
-
vk_buffer d_X = nullptr, d_G = nullptr, d_GM = nullptr, d_GV = nullptr, d_P = nullptr;
size_t x_offset = 0, g_offset = 0, gm_offset = 0, gv_offset = 0, p_offset = 0;
bool X_uma = false, G_uma = false, GM_uma = false, GV_uma = false, P_uma = false;
@@ -8438,19 +8656,39 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
}
+static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
+ const uint32_t ne = (uint32_t)node->ne[0];
+ const uint32_t denom = ctx->device->pipeline_add_rms[0][0][0]->wg_denoms[0];
+ const uint32_t num_partials = CEIL_DIV(ne, denom);
+ return num_partials;
+}
+
+static uint32_t ggml_vk_rms_partials_size(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
+ const uint32_t num_partials = ggml_vk_rms_num_partials(ctx, node);
+ const uint32_t num_bytes = ROUNDUP_POW2(num_partials * sizeof(uint32_t), ctx->device->partials_binding_alignment);
+ return num_bytes;
+}
+
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
+ uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, dst) : 0;
+
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
- op_params[0], 0.0f, 0,
+ op_params[0], 0.0f, (int32_t)param3,
}, dryrun);
+
+ if (ctx->do_add_rms_partials) {
+ ctx->prealloc_size_add_rms_partials_offset += ggml_vk_rms_partials_size(ctx, src0);
+ ctx->do_add_rms_partials = false;
+ }
}
static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -8588,11 +8826,19 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
}
static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+ vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0));
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, p, dryrun);
}
static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun);
+ vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, p, dryrun);
+}
+
+static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+ vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
+ p.weight = 1.0f / (float)src0->ne[0];
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_MEAN, p, dryrun);
}
static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -9720,6 +9966,14 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
}
ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k);
}
+ if (ctx->prealloc_add_rms_partials == nullptr || (ctx->prealloc_size_add_rms_partials > 0 && ctx->prealloc_add_rms_partials->size < ctx->prealloc_size_add_rms_partials)) {
+ VK_LOG_MEMORY("ggml_vk_preallocate_buffers(add_partials_size: " << ctx->prealloc_add_rms_partials << ")");
+ // Resize buffer
+ if (ctx->prealloc_add_rms_partials != nullptr) {
+ ggml_vk_destroy_buffer(ctx->prealloc_add_rms_partials);
+ }
+ ctx->prealloc_add_rms_partials = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_add_rms_partials);
+ }
}
static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
@@ -9776,10 +10030,23 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
return false;
}
break;
+ case GGML_OP_ADD:
+ {
+ int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops;
+ if (next_node_idx < cgraph->n_nodes &&
+ cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM &&
+ cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] &&
+ ggml_nrows(cgraph->nodes[next_node_idx]) == 1 &&
+ ctx->device->add_rms_fusion) {
+ if (dryrun) {
+ ctx->prealloc_size_add_rms_partials += ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
+ }
+ ctx->do_add_rms_partials = true;
+ }
+ } break;
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
case GGML_OP_GET_ROWS:
- case GGML_OP_ADD:
case GGML_OP_ADD_ID:
case GGML_OP_ACC:
case GGML_OP_SUB:
@@ -9815,6 +10082,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_ARGSORT:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
case GGML_OP_ARGMAX:
case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL:
@@ -9884,6 +10152,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_ARGSORT:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
case GGML_OP_ARGMAX:
case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL:
@@ -9899,6 +10168,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
// do the only thing needed for the dryrun.
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op);
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+ if (node->op == GGML_OP_RMS_NORM) {
+ ctx->do_add_rms_partials = false;
+ }
return false;
}
default:
@@ -9906,6 +10178,80 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
}
}
+ if (!dryrun) {
+ // This logic detects dependencies between nodes in the graph and calls ggml_vk_sync_buffers
+ // to synchronize them. This handles most "normal" synchronization when computing the graph, and when
+ // there is no auxiliary memory use, it shouldn't be necessary to call ggml_vk_sync_buffers
+ // outside of this logic. When a node uses one of the prealloc buffers for something like
+ // dequantization or split_k, additional synchronization is needed between those passes.
+ bool need_sync = false;
+
+ // Check whether "node" requires synchronization. The node requires synchronization if it
+ // overlaps in memory with another unsynchronized node and at least one of them is a write.
+ // Destination nodes are checked against both the written/read lists. Source nodes are only
+ // checked against the written list. Two nodes overlap in memory if they come from the same
+ // buffer and the tensor or view ranges overlap.
+ auto const &overlaps_unsynced = [&](const ggml_tensor *node, const std::vector<const ggml_tensor *> &unsynced_nodes) -> bool {
+ if (unsynced_nodes.size() == 0) {
+ return false;
+ }
+ auto n_base = vk_tensor_offset(node) + node->view_offs;
+ auto n_size = ggml_nbytes(node);
+ ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)node->buffer->context;
+ vk_buffer a_buf = a_buf_ctx->dev_buffer;
+ for (auto &other : unsynced_nodes) {
+ ggml_backend_vk_buffer_context * o_buf_ctx = (ggml_backend_vk_buffer_context *)other->buffer->context;
+ vk_buffer o_buf = o_buf_ctx->dev_buffer;
+ if (a_buf == o_buf) {
+ auto o_base = vk_tensor_offset(other) + other->view_offs;
+ auto o_size = ggml_nbytes(other);
+
+ if ((o_base <= n_base && n_base < o_base + o_size) ||
+ (n_base <= o_base && o_base < n_base + n_size)) {
+ return true;
+ }
+ }
+ }
+ return false;
+ };
+
+ // For all fused ops, check if the destination node or any of the source
+ // nodes require synchronization.
+ for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1 && !need_sync; ++i) {
+ const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
+ if (overlaps_unsynced(cur_node, ctx->unsynced_nodes_read) || overlaps_unsynced(cur_node, ctx->unsynced_nodes_written)) {
+ need_sync = true;
+ break;
+ }
+ for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
+ if (!cur_node->src[j]) {
+ continue;
+ }
+ if (overlaps_unsynced(cur_node->src[j], ctx->unsynced_nodes_written)) {
+ need_sync = true;
+ break;
+ }
+ }
+ }
+ if (need_sync) {
+ ctx->unsynced_nodes_written.clear();
+ ctx->unsynced_nodes_read.clear();
+ ggml_vk_sync_buffers(ctx, compute_ctx);
+ }
+ // Add the last fused node and all fused source nodes to the unsynchronized list.
+ const ggml_tensor * last_node = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
+ ctx->unsynced_nodes_written.push_back(last_node);
+ for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
+ const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
+ for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
+ if (!cur_node->src[j]) {
+ continue;
+ }
+ ctx->unsynced_nodes_read.push_back(cur_node->src[j]);
+ }
+ }
+ }
+
switch (node->op) {
case GGML_OP_REPEAT:
ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun);
@@ -10087,6 +10433,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_SUM_ROWS:
ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun);
+ break;
+ case GGML_OP_MEAN:
+ ggml_vk_mean(ctx, compute_ctx, src0, node, dryrun);
+
break;
case GGML_OP_ARGMAX:
ggml_vk_argmax(ctx, compute_ctx, src0, node, dryrun);
@@ -10246,6 +10596,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_OP_ARGSORT:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
case GGML_OP_ARGMAX:
case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL:
@@ -10364,6 +10715,10 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
ctx->gc.temp_buffers.clear();
ctx->prealloc_y_last_pipeline_used = {};
+ ctx->unsynced_nodes_written.clear();
+ ctx->unsynced_nodes_read.clear();
+ ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false;
+
ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);
ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool);
@@ -10882,6 +11237,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast(&dul));
}
+ ctx->prealloc_size_add_rms_partials = 0;
+ ctx->prealloc_size_add_rms_partials_offset = 0;
+ ctx->do_add_rms_partials = false;
+
uint64_t total_mat_mul_bytes = 0;
for (int i = 0; i < cgraph->n_nodes; i++) {
if (!ctx->device->disable_fusion) {
@@ -10950,6 +11309,19 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->prealloc_y_last_pipeline_used = nullptr;
ctx->prealloc_y_last_tensor_used = nullptr;
+ if (ctx->prealloc_size_add_rms_partials) {
+ if (ctx->compute_ctx.expired()) {
+ compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
+ ctx->compute_ctx = compute_ctx;
+ ggml_vk_ctx_begin(ctx->device, compute_ctx);
+ } else {
+ compute_ctx = ctx->compute_ctx.lock();
+ }
+ // initialize partial sums to zero.
+ ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials);
+ ggml_vk_sync_buffers(ctx, compute_ctx);
+ }
+
// Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
// Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
// (and scaled down based on model size, so smaller models submit earlier).
@@ -11280,8 +11652,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
auto device = ggml_vk_get_device(ctx->device);
bool coopmat2 = device->coopmat2;
- FaHeadSizes head_sizes = fa_get_head_sizes(op->src[1]->ne[0], op->src[2]->ne[0]);
- if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
+ uint32_t HSK = op->src[1]->ne[0];
+ uint32_t HSV = op->src[2]->ne[0];
+ if ((HSK % 8) != 0 || (HSV % 8) != 0) {
return false;
}
if (op->src[4] && op->src[4]->type != GGML_TYPE_F32) {
@@ -11483,8 +11856,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
+ return true;
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
+ return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_ARGMAX:
case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL:
@@ -11501,14 +11877,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
// Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
const vk_device& device = ggml_vk_get_device(ctx->device);
- bool is_Apple = ggml_vk_get_device(ctx->device)->vendor_id == VK_VENDOR_ID_APPLE;
// Channel-contiguous format is not supported yet.
return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
op->src[1]->type == GGML_TYPE_F32 &&
op->type == GGML_TYPE_F32 &&
ggml_is_contiguous(op->src[0]) &&
ggml_is_contiguous(op->src[1]) &&
- ggml_is_contiguous(op)) && !is_Apple;
+ ggml_is_contiguous(op));
}
default:
return false;
@@ -11903,7 +12278,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
} else if (tensor->op == GGML_OP_CONCAT) {
tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params);
} else if (tensor->op == GGML_OP_UPSCALE) {
- tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]);
+ tensor_clone = ggml_interpolate(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]);
} else if (tensor->op == GGML_OP_SCALE) {
const float * params = (const float *)tensor->op_params;
tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]);
@@ -12043,6 +12418,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_SUM_ROWS) {
tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]);
+ } else if (tensor->op == GGML_OP_MEAN) {
+ tensor_clone = ggml_mean(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_ARGMAX) {
tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_COUNT_EQUAL) {
@@ -12140,11 +12517,9 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph *
if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) {
return;
}
- bool fused_rms_norm_mul = false;
if (ctx->num_additional_fused_ops == 1 &&
tensor->op == GGML_OP_RMS_NORM &&
cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
- fused_rms_norm_mul = true;
tensor = cgraph->nodes[tensor_idx + 1];
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/add.comp b/ggml/src/ggml-vulkan/vulkan-shaders/add.comp
index 2b4085c4f8..00cf2dd62f 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/add.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/add.comp
@@ -1,20 +1,34 @@
#version 450
#extension GL_EXT_shader_16bit_storage : require
+#if ADD_RMS
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#extension GL_KHR_shader_subgroup_basic : enable
+#endif
#include "types.comp"
#include "generic_binary_head.comp"
const uint num_threads = 256;
+layout (binding = 3, std430) buffer PartialBuf {float partial_sums[];};
+
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+#if ADD_RMS
+// XXX TODO this could be sized based on number of subgroups, but that's not considered a constant
+shared FLOAT_TYPE sumsh[num_threads];
+#endif
+
void main() {
uint idx = get_idx();
+ uint orig_idx = idx;
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
const uint num_iter = 2;
+ FLOAT_TYPE sum_sq = 0;
+
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
if (idx >= p.ne) {
continue;
@@ -22,8 +36,34 @@ void main() {
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
- data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
+ FLOAT_TYPE sum = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]);
+ sum_sq += sum*sum;
+
+ data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);
idx += num_threads;
}
+
+#if ADD_RMS
+ if (p.param3 != 0) {
+ // reduce the sum within each subgroup, then across subgroups
+ const uint NumSubgroups = num_threads / gl_SubgroupSize;
+ sum_sq = subgroupAdd(sum_sq);
+ if (gl_SubgroupInvocationID == 0) {
+ sumsh[gl_SubgroupID] = sum_sq;
+ }
+ barrier();
+ [[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {
+ if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
+ sum_sq += sumsh[gl_SubgroupID + s];
+ sumsh[gl_SubgroupID] = sum_sq;
+ }
+ barrier();
+ }
+
+ if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
+ partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
+ }
+ }
+#endif
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp
index b57c9dcfc4..f73e17e1fa 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp
@@ -9,6 +9,10 @@ layout (constant_id = 4) const uint32_t HSV = 32;
layout (constant_id = 5) const uint32_t Clamp = 0;
layout (constant_id = 6) const uint32_t D_split = 16;
+// Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
+const uint32_t HSK_pad = (HSK + 15) & ~15;
+const uint32_t HSV_pad = (HSV + 15) & ~15;
+
layout (push_constant) uniform parameter {
uint32_t N;
uint32_t KV;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
index 81cc3f81fc..97c2a54129 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
@@ -46,14 +46,14 @@ const uint32_t MatBc = 16;
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
-const uint32_t qstride = HSK / 4 + 2; // in units of f16vec4
+const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
shared f16vec4 Qf[Br * qstride];
// Avoid padding for hsk==256 to make it fit in 48KB shmem.
const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
shared ACC_TYPE sfsh[Bc * sfshstride];
-const uint32_t kshstride = HSK / 4 + 2; // in units of f16vec4
+const uint32_t kshstride = HSK_pad / 4 + 2; // in units of f16vec4
shared f16vec4 ksh[Bc * kshstride];
shared float slope[Br];
@@ -74,6 +74,21 @@ void main() {
#define tile_row(r) (row_tid * rows_per_thread + (r))
+ // Zero-initialize shared memory for Q/K when HSK is not a multiple of 16 (HSK_pad > HSK).
+ if ((HSK % 16) != 0) {
+ [[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) {
+ if (i + tid < Br * qstride) {
+ Qf[i + tid] = f16vec4(0);
+ }
+ }
+ [[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) {
+ if (i + tid < Bc * kshstride) {
+ ksh[i + tid] = f16vec4(0);
+ }
+ }
+ barrier();
+ }
+
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
@@ -151,14 +166,14 @@ void main() {
}
barrier();
- // K * Q^T -> S^T: Bc x HSK * HSK x Br -> Bc x Br
+ // K * Q^T -> S^T: Bc x HSK_pad * HSK_pad x Br -> Bc x Br
// Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
// This is written transposed in order to allow for N being 8 if implementations need it
coopmat SfMat = coopmat(0);
coopmat KMat;
coopmat QMat;
- for (uint32_t d = 0; d < HSK / 16; ++d) {
+ for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
index b0564ca0bf..77ae5ff01d 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
@@ -104,16 +104,16 @@ void main() {
tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
- coopmat Q;
- coopmat Qf16;
+ coopmat Q;
+ coopmat Qf16;
uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
- coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK));
+ coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad));
- Qf16 = coopmat(Q);
+ Qf16 = coopmat(Q);
Qf16 *= float16_t(p.scale);
- coopmat O = coopmat(0);
+ coopmat O = coopmat(0);
coopmat L, M;
@@ -140,10 +140,10 @@ void main() {
coopmat S = coopmat(0);
- coopmat K_T;
+ coopmat K_T;
uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
- coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK), tensorViewTranspose DECODEFUNC);
+ coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
S = coopMatMulAdd(Qf16, K_T, S);
if (p.logit_softcap != 0.0f) {
@@ -208,31 +208,31 @@ void main() {
rowsum = coopmat(0.0);
rowsum = coopMatMulAdd(P_A, One, rowsum);
- coopmat V;
+ coopmat