Merge branch 'master' into quantize

This commit is contained in:
Ed Addario 2025-08-30 10:17:45 +01:00
commit 09198c470b
No known key found for this signature in database
GPG Key ID: E7875815A3230993
86 changed files with 4437 additions and 1429 deletions

View File

@ -1106,7 +1106,7 @@ static void common_params_print_completion(common_params_context & ctx_arg) {
printf("\"\n\n");
printf(" case \"$prev\" in\n");
printf(" --model)\n");
printf(" --model|-m)\n");
printf(" COMPREPLY=( $(compgen -f -X '!*.gguf' -- \"$cur\") $(compgen -d -- \"$cur\") )\n");
printf(" return 0\n");
printf(" ;;\n");
@ -2555,7 +2555,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--lora"}, "FNAME",
"path to LoRA adapter (can be repeated to use multiple adapters)",
[](common_params & params, const std::string & value) {
params.lora_adapters.push_back({ std::string(value), 1.0, nullptr });
params.lora_adapters.push_back({ std::string(value), 1.0, "", "", nullptr });
}
// we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
@ -2563,7 +2563,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--lora-scaled"}, "FNAME", "SCALE",
"path to LoRA adapter with user defined scaling (can be repeated to use multiple adapters)",
[](common_params & params, const std::string & fname, const std::string & scale) {
params.lora_adapters.push_back({ fname, std::stof(scale), nullptr });
params.lora_adapters.push_back({ fname, std::stof(scale), "", "", nullptr });
}
// we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
@ -3538,6 +3538,22 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--fim-qwen-30b-default"},
string_format("use default Qwen 3 Coder 30B A3B Instruct (note: can download weights from the internet)"),
[](common_params & params) {
params.model.hf_repo = "ggml-org/Qwen3-Coder-30B-A3B-Instruct-Q8_0-GGUF";
params.model.hf_file = "qwen3-coder-30b-a3b-instruct-q8_0.gguf";
params.port = 8012;
params.n_gpu_layers = 99;
params.flash_attn = true;
params.n_ubatch = 1024;
params.n_batch = 1024;
params.n_ctx = 0;
params.n_cache_reuse = 256;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{ "--diffusion-steps" }, "N",
string_format("number of diffusion steps (default: %d)", params.diffusion.steps),

View File

@ -622,6 +622,7 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
case COMMON_CHAT_FORMAT_GRANITE: return "Granite";
case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS";
default:
throw std::runtime_error("Unknown chat format");
}
@ -2059,6 +2060,94 @@ static void common_chat_parse_granite(common_chat_msg_parser & builder) {
}
}
static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
// Parse thinking tags first - this handles the main reasoning content
builder.try_parse_reasoning("<seed:think>", "</seed:think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// Parse tool calls - Seed-OSS uses <seed:tool_call> format
static const common_regex tool_call_begin_regex("<seed:tool_call>");
static const common_regex tool_call_end_regex("</seed:tool_call>");
static const common_regex function_regex("<function=([^>]+)>");
static const common_regex param_regex("<parameter=([^>]+)>");
while (auto tool_res = builder.try_find_regex(tool_call_begin_regex)) {
builder.consume_spaces(); // Consume whitespace after <seed:tool_call>
// Look for function call inside tool call, ignore any content before it
if (auto func_res = builder.try_find_regex(function_regex, std::string::npos, false)) {
auto function_name = builder.str(func_res->groups[1]);
// Parse Seed-OSS parameters <parameter=name>value</parameter>
json args = json::object();
// Parse all parameters
while (auto param_res = builder.try_find_regex(param_regex, std::string::npos, false)) {
// again, ignore noise around parameters
auto param_name = builder.str(param_res->groups[1]);
builder.move_to(param_res->groups[0].end);
builder.consume_spaces(); // Consume whitespace after parameter
auto savedPos = builder.pos();
if (auto param_parse = builder.try_find_literal("</parameter>")) {
auto param = param_parse->prelude;
builder.move_to(savedPos);
try {
if (auto param_res = builder.try_consume_json()) {
args[param_name] = param_res->json;
} else {
args[param_name] = param;
}
} catch (json::exception &) {
args[param_name] = param;
}
} else {
throw common_chat_msg_partial_exception("Incomplete tool parameter");
}
}
// Look for closing function tag
auto end_func = builder.try_find_literal("</function>");
if (end_func) {
builder.move_to(end_func->groups[0].end);
builder.consume_spaces(); // Consume whitespace after </function>
// Add the tool call with parsed arguments, but only if we REALLY got the literal
auto eaten_fragment = builder.input().substr(end_func->groups[0].begin, end_func->groups[0].end);
auto funlen = std::string("</function>").length();
if (eaten_fragment.length() >= funlen && eaten_fragment.substr(0, funlen) == std::string("</function>")) {
if (!builder.add_tool_call(function_name, "", args.dump())) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
} else {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
} else {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
// Look for closing tool call tag
if (auto end_tool = builder.try_find_regex(tool_call_end_regex, std::string::npos, false)) {
builder.move_to(end_tool->groups[0].end);
builder.consume_spaces(); // Consume trailing whitespace after tool call
} else {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
} else {
// No function found - don't consume content here, let it be handled at the end
break;
}
}
// Consume any remaining whitespace after all tool call processing
builder.consume_spaces();
auto remaining = builder.consume_rest();
// If there's any non-whitespace content remaining, add it as content
if (!string_strip(remaining).empty()) {
builder.add_content(remaining);
}
}
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.prompt = apply(tmpl, inputs);
@ -2075,8 +2164,62 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha
return data;
}
static common_chat_params common_chat_params_init_seed_oss(
const common_chat_template & tmpl,
templates_params & params,
const common_chat_templates_inputs & inputs)
{
common_chat_params data;
data.prompt = apply(tmpl, params);
data.format = COMMON_CHAT_FORMAT_SEED_OSS;
if (string_ends_with(data.prompt, "<seed:think>")) {
if (!inputs.enable_thinking) {
data.prompt += "</seed:think>";
} else {
data.thinking_forced_open = true;
}
}
if (params.tools.is_array() && !params.tools.empty()) {
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(params.tools, [&](const json & tool) {
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
builder.resolve_refs(parameters);
// Create rule for Seed-OSS function call format
std::string param_rules;
if (parameters.contains("properties")) {
for (const auto & [key, value] : parameters.at("properties").items()) {
param_rules += "\"<parameter=" + key + ">\"" + builder.add_schema(name + "-arg-" + key, value) +
"\"</parameter>\"";
}
}
tool_rules.push_back(builder.add_rule(name + "-call",
"\"<seed:tool_call>\" space \"<function=" + name + ">\" space " +
param_rules +
" \"</function>\" space \"</seed:tool_call>\""));
});
data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<seed:tool_call>" });
data.preserved_tokens = {
"<seed:think>", "</seed:think>", "<seed:tool_call>", "</seed:tool_call>",
"<function=", "</function>", "<parameter=", "</parameter>",
};
builder.add_rule("root", string_join(tool_rules, " | "));
});
}
return data;
}
static common_chat_params common_chat_templates_apply_jinja(
const struct common_chat_templates * tmpls,
const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs)
{
templates_params params;
@ -2145,6 +2288,11 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_gpt_oss(tmpl, params);
}
// Seed-OSS
if (src.find("<seed:think>") != std::string::npos) {
return common_chat_params_init_seed_oss(tmpl, params, inputs);
}
// Use generic handler when mixing tools + JSON schema.
// TODO: support that mix in handlers below.
if ((params.tools.is_array() && params.json_schema.is_object())) {
@ -2303,6 +2451,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_GPT_OSS:
common_chat_parse_gpt_oss(builder);
break;
case COMMON_CHAT_FORMAT_SEED_OSS:
common_chat_parse_seed_oss(builder);
break;
default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
}

View File

@ -111,6 +111,7 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_GRANITE,
COMMON_CHAT_FORMAT_GPT_OSS,
COMMON_CHAT_FORMAT_SEED_OSS,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};

View File

@ -988,7 +988,12 @@ struct common_init_result common_init_from_params(common_params & params) {
return iparams;
}
char buf[1024];
la.ptr = lora.get();
llama_adapter_meta_val_str(la.ptr, "adapter.lora.task_name", buf, sizeof(buf));
la.task_name = buf;
llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
la.prompt_prefix = buf;
iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
}

View File

@ -34,6 +34,9 @@ struct common_adapter_lora_info {
std::string path;
float scale;
std::string task_name;
std::string prompt_prefix;
struct llama_adapter_lora * ptr;
};

View File

@ -72,6 +72,7 @@ class ModelBase:
endianess: gguf.GGUFEndian
use_temp_file: bool
lazy: bool
dry_run: bool
part_names: list[str]
is_safetensors: bool
hparams: dict[str, Any]
@ -111,6 +112,7 @@ class ModelBase:
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
self.use_temp_file = use_temp_file
self.lazy = not eager or (remote_hf_model_id is not None)
self.dry_run = dry_run
self.remote_hf_model_id = remote_hf_model_id
if remote_hf_model_id is not None:
self.is_safetensors = True
@ -1216,6 +1218,55 @@ class TextModel(ModelBase):
raise NotImplementedError("Only MEAN, CLS, and LAST pooling types supported")
self.gguf_writer.add_pooling_type(pooling_type)
def _set_vocab_interns1(self):
tokens: list[str] = []
toktypes: list[int] = []
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
vocab = getattr(tokenizer, 'vocab', tokenizer.get_vocab())
vocab_size = self.hparams.get("vocab_size", len(vocab))
assert max(vocab.values()) < vocab_size
tokpre = self.get_vocab_base_pre(tokenizer)
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab.items()}
added_vocab = tokenizer.get_added_vocab()
added_tokens_decoder = tokenizer.added_tokens_decoder
for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.UNUSED)
else:
token: str = reverse_vocab[i]
if token in added_vocab:
# The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized.
# To avoid unexpected issues - we make sure to normalize non-normalized tokens
if not added_tokens_decoder[i].normalized:
previous_token = token
token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
if previous_token != token:
logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
if added_tokens_decoder[i].special or self.does_token_look_special(token):
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.USER_DEFINED)
else:
toktypes.append(gguf.TokenType.NORMAL)
tokens.append(token)
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
special_vocab._set_special_token("bos", 151643)
special_vocab.add_to_gguf(self.gguf_writer)
class MmprojModel(ModelBase):
model_type = ModelType.MMPROJ
@ -2932,7 +2983,8 @@ class Qwen2Model(TextModel):
if "language_model." in name:
name = name.replace("language_model.", "") # for InternVL
if name.startswith("mlp") or name.startswith("multi_modal_projector") \
or name.startswith("vision_model") or name.startswith("audio_tower"):
or name.startswith("vision_model") or name.startswith("audio_tower") \
or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"):
# skip vision and audio tensors
return []
yield from super().modify_tensors(data_torch, name, bid)
@ -3109,7 +3161,7 @@ class LLaDAModel(TextModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Ernie4_5_ForCausalLM")
@ModelBase.register("Ernie4_5_ForCausalLM", "Ernie4_5ForCausalLM")
class Ernie4_5Model(TextModel):
model_arch = gguf.MODEL_ARCH.ERNIE4_5
@ -3604,6 +3656,19 @@ class Qwen2MoeModel(TextModel):
class Qwen3Model(Qwen2Model):
model_arch = gguf.MODEL_ARCH.QWEN3
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
self.origin_hf_arch = hparams.get('architectures', [None])[0]
def set_vocab(self):
# deal with intern-s1-mini
if self.origin_hf_arch == 'InternS1ForConditionalGeneration':
self._set_vocab_interns1()
return
super().set_vocab()
@ModelBase.register("Qwen3MoeForCausalLM")
class Qwen3MoeModel(Qwen2MoeModel):
@ -3620,73 +3685,7 @@ class Qwen3MoeModel(Qwen2MoeModel):
self._set_vocab_interns1()
return
try:
self._set_vocab_sentencepiece()
except FileNotFoundError:
self._set_vocab_gpt2()
def _set_vocab_interns1(self):
tokens: list[str] = []
toktypes: list[int] = []
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
vocab = getattr(tokenizer, 'vocab', tokenizer.get_vocab())
vocab_size = self.hparams.get("vocab_size", len(vocab))
assert max(vocab.values()) < vocab_size
tokpre = self.get_vocab_base_pre(tokenizer)
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab.items()}
added_vocab = tokenizer.get_added_vocab()
added_tokens_decoder = tokenizer.added_tokens_decoder
for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.UNUSED)
else:
token: str = reverse_vocab[i]
if token in added_vocab:
# The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized.
# To avoid unexpected issues - we make sure to normalize non-normalized tokens
if not added_tokens_decoder[i].normalized:
previous_token = token
token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
if previous_token != token:
logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
if added_tokens_decoder[i].special or self.does_token_look_special(token):
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.USER_DEFINED)
else:
toktypes.append(gguf.TokenType.NORMAL)
tokens.append(token)
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
special_tokens_map_file = self.dir_model / 'special_tokens_map.json'
additional_special_tokens = []
if special_tokens_map_file.is_file():
with open(special_tokens_map_file, encoding = 'utf-8') as f:
additional_special_tokens = json.load(f).get('additional_special_tokens', [])
tokenizer_cfg_file = self.dir_model / 'special_tokens_map.json'
if tokenizer_cfg_file.is_file():
with open(tokenizer_cfg_file, encoding = 'utf-8') as f:
added_tokens_decoder = json.load(f).get('added_tokens_decoder', {})
token2ids_map = {data['content'] : int(token) for token, data in added_tokens_decoder.items() if data['special']}
for token in additional_special_tokens:
if token in token2ids_map:
special_vocab._set_special_token(token, token2ids_map[token])
special_vocab._set_special_token('eos', 151645)
special_vocab._set_special_token("bos", 151643)
special_vocab.add_to_gguf(self.gguf_writer)
super().set_vocab()
@ModelBase.register("GPT2LMHeadModel")
@ -4874,11 +4873,35 @@ class NeoBert(BertModel):
@ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
class XLMRobertaModel(BertModel):
model_arch = gguf.MODEL_ARCH.BERT
_lora_files = {}
_lora_names = []
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any):
hparams = kwargs.pop("hparams", None)
if hparams is None:
hparams = ModelBase.load_hparams(dir_model, False)
if lora_names := hparams.get("lora_adaptations"):
self._lora_names = lora_names
self.model_arch = gguf.MODEL_ARCH.JINA_BERT_V3
super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs)
self._xlmroberta_tokenizer_init()
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
if self._lora_names:
for name in self._lora_names:
fname = self.add_prefix_to_filename(self.fname_out, f"lora-{name}-")
self._lora_files[name] = gguf.GGUFWriter(fname, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file, dry_run=self.dry_run)
return super().generate_extra_tensors()
def set_type(self):
for lora_writer in self._lora_files.values():
lora_writer.add_type(gguf.GGUFType.ADAPTER)
lora_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
super().set_type()
def set_vocab(self):
self._xlmroberta_set_vocab()
@ -4888,13 +4911,62 @@ class XLMRobertaModel(BertModel):
if name.startswith("roberta."):
name = name[8:]
# jina-embeddings-v3
if ".parametrizations." in name:
name = name.replace(".parametrizations.", ".")
if name.endswith(".original"):
name = name[:-9]
# position embeddings start at pad_token_id + 1, so just chop down the weight tensor
if name == "embeddings.position_embeddings.weight":
if self._position_offset is not None:
data_torch = data_torch[self._position_offset:,:]
if name.endswith(".0.lora_A") or name.endswith(".0.lora_B"):
if name.startswith("pooler.dense"):
return []
num_loras = data_torch.size(0)
assert num_loras == len(self._lora_names)
# Split out each LoRA in their own GGUF
for i, lora_writer in enumerate(self._lora_files.values()):
new_name = self.map_tensor_name(name[:-9]) + name[-7:].lower()
data = data_torch[i, :, :]
# Transpose/flip token_embd/types into correct shape
if new_name == "token_embd.weight.lora_b":
data = data.T
elif new_name.startswith("token_types.weight."):
new_name = new_name[:-1] + ("a" if new_name[-1:] == "b" else "b")
lora_writer.add_tensor(new_name, data.float().numpy(), raw_dtype=gguf.GGMLQuantizationType.F32)
return []
return super().modify_tensors(data_torch, name, bid)
def set_gguf_parameters(self):
super().set_gguf_parameters()
# jina-embeddings-v3
if rotary_emb_base := self.hparams.get("rotary_emb_base"):
self.gguf_writer.add_rope_freq_base(rotary_emb_base)
lora_alpha = self.hparams.get("lora_alpha")
if lora_prompt_prefixes := self.hparams.get("task_instructions"):
assert self._lora_files and all(lora_name in lora_prompt_prefixes for lora_name in self._lora_files.keys())
for lora_name, lora_writer in self._lora_files.items():
lora_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, lora_alpha if lora_alpha is not None else 1.0)
lora_writer.add_string(gguf.Keys.Adapter.LORA_TASK_NAME, lora_name)
if lora_prompt_prefixes:
lora_writer.add_string(gguf.Keys.Adapter.LORA_PROMPT_PREFIX, lora_prompt_prefixes[lora_name])
def write(self):
super().write()
for lora_writer in self._lora_files.values():
lora_writer.write_header_to_file()
lora_writer.write_kv_data_to_file()
lora_writer.write_tensors_to_file(progress=True)
lora_writer.close()
@ModelBase.register("GemmaForCausalLM")
class GemmaModel(TextModel):
@ -6257,9 +6329,11 @@ class DeepseekModel(TextModel):
raise ValueError(f"Unprocessed experts: {experts}")
@ModelBase.register("DeepseekV2ForCausalLM")
@ModelBase.register("DeepseekV3ForCausalLM")
@ModelBase.register("KimiVLForConditionalGeneration")
@ModelBase.register(
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"KimiVLForConditionalGeneration",
)
class DeepseekV2Model(TextModel):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
@ -7472,9 +7546,13 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
]
# n_group and d_inner are used during reshape_tensors for mamba2
self.d_model = self.find_hparam(["hidden_size", "d_model"])
self.n_group = self.find_hparam(["n_groups"])
self.d_inner = self.find_hparam(["expand"]) * self.d_model
# NOTE: Explicitly include hparam prefix prefix for d_model to
# disambiguate with top-level head_dim
# NOTE 2: If needed for future models, this can be isolated in a method
# to separate the prefix setting and teh keys used
self.d_model = self.find_hparam([f"{self.hparam_prefixes[0]}_head_dim", "hidden_size", "d_model"])
self.n_group = self.find_hparam(["n_groups", "num_groups"])
self.d_inner = self.find_hparam(["expand", "num_heads"]) * self.d_model
def get_attn_layers(self):
# Explicit list of layer type names
@ -7535,12 +7613,12 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
## Mamba mixer params ##
self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state", "state_dim", "ssm_state_size"]))
self.gguf_writer.add_ssm_group_count(self.n_group)
self.gguf_writer.add_ssm_inner_size(self.d_inner)
# NOTE: The mamba_dt_rank is _not_ the right field for how this is used
# in llama.cpp
self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads", "num_heads"]))
## Attention params ##
head_count_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
@ -7567,6 +7645,55 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
Mamba2Model.set_vocab(self)
@ModelBase.register("NemotronHForCausalLM")
class NemotronHModel(GraniteHybridModel):
"""Hybrid mamba2/attention model from NVIDIA"""
model_arch = gguf.MODEL_ARCH.NEMOTRON_H
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Save the top-level head_dim for later
self.head_dim = self.hparams.get("head_dim", self.hparams.get("attention_head_dim"))
assert self.head_dim is not None, "Could not find the attention head dim in config"
# Don't use expand to calculate d_inner
self.d_inner = self.find_hparam(["num_heads"]) * self.d_model
# Update the ssm / attn / mlp layers
# M: Mamba2, *: Attention, -: MLP
hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
self._ssm_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "M"]
self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "-"]
def get_attn_layers(self):
hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
assert len(hybrid_override_pattern) == self.block_count, "Mismatch between hybrid override and num_hidden_layers!"
return [i for i, val in enumerate(hybrid_override_pattern) if val == "*"]
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_key_length(self.head_dim)
self.gguf_writer.add_value_length(self.head_dim)
# Set feed_forward_length
# NOTE: This will trigger an override warning. This is preferrable to
# duplicating all the parent logic
n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"])
self.gguf_writer.add_feed_forward_length([
n_ff if i in self._mlp_layers else 0 for i in range(self.block_count)
])
def set_vocab(self):
super().set_vocab()
# The tokenizer _does_ add a BOS token (via post_processor type
# TemplateProcessing) but does not set add_bos_token to true in the
# config, so we need to explicitly override it here.
self.gguf_writer.add_add_bos_token(True)
@ModelBase.register("BailingMoeForCausalLM")
class BailingMoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.BAILINGMOE
@ -8510,6 +8637,43 @@ class PixtralModel(LlavaVisionModel):
return "mm.2.weight"
return super().map_tensor_name(name, try_suffixes)
@ModelBase.register("KimiVLForConditionalGeneration")
class KimiVLModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams_vision is not None
self.hparams_vision["image_size"] = 64 * 14 # for compatibility
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.KIMIVL)
self.gguf_writer.add_vision_use_gelu(True)
self.gguf_writer.add_vision_projector_scale_factor(2)
# eps is the same as pytorch's default value
assert self.hparams_vision is not None
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("layer_norm_eps", 1e-5))
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name
if is_vision_tensor:
if "pos_emb.weight" in name:
data_torch = data_torch.view(data_torch.shape[0] * data_torch.shape[1], data_torch.shape[2])
elif "wqkv" in name:
split_dim = 0 if "weight" in name else -1
wq, wk, wv = data_torch.chunk(3, dim=split_dim)
return [
(self.map_tensor_name(name.replace("wqkv", "wq")), wq),
(self.map_tensor_name(name.replace("wqkv", "wk")), wk),
(self.map_tensor_name(name.replace("wqkv", "wv")), wv)
]
return [(self.map_tensor_name(name), data_torch)]
return [] # skip other tensors
###### CONVERSION LOGIC ######

View File

@ -21,6 +21,8 @@ Function calling is supported for all models (see https://github.com/ggml-org/ll
- Use `--chat-template-file` to override the template when appropriate (see examples below)
- Generic support may consume more tokens and be less efficient than a model's native format.
- Multiple/parallel tool calling is supported on some models but disabled by default, enable it by passing `"parallel_tool_calls": true` in the completion endpoint payload.
<details>
<summary>Show some common templates and which format handler they use</summary>

View File

@ -6,7 +6,7 @@ Download [MiniCPM-V-4](https://huggingface.co/openbmb/MiniCPM-V-4) PyTorch model
### Build llama.cpp
Readme modification time: 20250206
Readme modification time: 20250731
If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md)

View File

@ -0,0 +1,47 @@
## MiniCPM-V 4.5
### Prepare models and code
Download [MiniCPM-V-4_5](https://huggingface.co/openbmb/MiniCPM-V-4_5) PyTorch model from huggingface to "MiniCPM-V-4_5" folder.
### Build llama.cpp
Readme modification time: 20250826
If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md)
Clone llama.cpp:
```bash
git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp
```
Build llama.cpp using `CMake`:
```bash
cmake -B build
cmake --build build --config Release
```
### Usage of MiniCPM-V 4
Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-V-4_5-gguf) by us)
```bash
python ./tools/mtmd/legacy-models/minicpmv-surgery.py -m ../MiniCPM-V-4_5
python ./tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-V-4_5 --minicpmv-projector ../MiniCPM-V-4_5/minicpmv.projector --output-dir ../MiniCPM-V-4_5/ --minicpmv_version 6
python ./convert_hf_to_gguf.py ../MiniCPM-V-4_5/model
# quantize int4 version
./build/bin/llama-quantize ../MiniCPM-V-4_5/model/ggml-model-f16.gguf ../MiniCPM-V-4_5/model/ggml-model-Q4_K_M.gguf Q4_K_M
```
Inference on Linux or Mac
```bash
# run in single-turn mode
./build/bin/llama-mtmd-cli -m ../MiniCPM-V-4_5/model/ggml-model-f16.gguf --mmproj ../MiniCPM-V-4_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
# run in conversation mode
./build/bin/llama-mtmd-cli -m ../MiniCPM-V-4_5/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-4_5/mmproj-model-f16.gguf
```

View File

@ -28,9 +28,40 @@ static std::string ggml_ne_string(const ggml_tensor * t) {
return str;
}
static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) {
size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
float v;
if (type == GGML_TYPE_F16) {
v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
} else if (type == GGML_TYPE_F32) {
v = *(float *) &data[i];
} else if (type == GGML_TYPE_I64) {
v = (float) *(int64_t *) &data[i];
} else if (type == GGML_TYPE_I32) {
v = (float) *(int32_t *) &data[i];
} else if (type == GGML_TYPE_I16) {
v = (float) *(int16_t *) &data[i];
} else if (type == GGML_TYPE_I8) {
v = (float) *(int8_t *) &data[i];
} else {
GGML_ABORT("fatal error");
}
return v;
}
static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) {
GGML_ASSERT(n > 0);
float sum = 0;
for (int64_t i3 = 0; i3 < ne[3]; i3++) {
for (int64_t i2 = 0; i2 < ne[2]; i2++) {
for (int64_t i1 = 0; i1 < ne[1]; i1++) {
for (int64_t i0 = 0; i0 < ne[0]; i0++) {
const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3);
sum += v;
}
}
}
}
for (int64_t i3 = 0; i3 < ne[3]; i3++) {
LOG(" [\n");
for (int64_t i2 = 0; i2 < ne[2]; i2++) {
@ -50,25 +81,8 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne
LOG("..., ");
i0 = ne[0] - n;
}
size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
float v;
if (type == GGML_TYPE_F16) {
v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
} else if (type == GGML_TYPE_F32) {
v = *(float *) &data[i];
} else if (type == GGML_TYPE_I64) {
v = (float) *(int64_t *) &data[i];
} else if (type == GGML_TYPE_I32) {
v = (float) *(int32_t *) &data[i];
} else if (type == GGML_TYPE_I16) {
v = (float) *(int16_t *) &data[i];
} else if (type == GGML_TYPE_I8) {
v = (float) *(int8_t *) &data[i];
} else {
GGML_ABORT("fatal error");
}
const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3);
LOG("%12.4f", v);
sum += v;
if (i0 < ne[0] - 1) LOG(", ");
}
LOG("],\n");

View File

@ -1,4 +1,5 @@
# Validation functions
MAKEFLAGS += --no-print-directory
define validate_model_path
@if [ -z "$(MODEL_PATH)" ]; then \
echo "Error: MODEL_PATH must be provided either as:"; \
@ -17,6 +18,13 @@ define validate_embedding_model_path
fi
endef
define quantize_model
@CONVERTED_MODEL="$(1)" QUANTIZED_TYPE="$(QUANTIZED_TYPE)" \
TOKEN_EMBD_TYPE="$(TOKEN_EMBD_TYPE)" OUTPUT_TYPE="$(OUTPUT_TYPE)" \
./scripts/utils/quantize.sh "$(1)" "$(QUANTIZED_TYPE)" "$(TOKEN_EMBD_TYPE)" "$(OUTPUT_TYPE)"
@echo "Export the quantized model path to $(2) variable in your environment"
endef
###
### Casual Model targets/recipes
###
@ -29,6 +37,20 @@ causal-convert-model:
METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \
./scripts/causal/convert-model.sh
causal-convert-mm-model-bf16: OUTTYPE=bf16
causal-convert-mm-model-bf16: MM_OUTTYPE=f16
causal-convert-mm-model-bf16: causal-convert-mm-model
causal-convert-mm-model:
$(call validate_model_path,causal-convert-mm-model)
@MODEL_NAME="$(MODEL_NAME)" OUTTYPE="$(OUTTYPE)" MODEL_PATH="$(MODEL_PATH)" \
METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \
./scripts/causal/convert-model.sh
@MODEL_NAME="$(MODEL_NAME)" OUTTYPE="$(MM_OUTTYPE)" MODEL_PATH="$(MODEL_PATH)" \
METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \
./scripts/causal/convert-model.sh --mmproj
causal-run-original-model:
$(call validate_model_path,causal-run-original-model)
@MODEL_PATH="$(MODEL_PATH)" ./scripts/causal/run-org-model.py
@ -67,9 +89,15 @@ causal-quantize-Q8_0: causal-quantize-model
causal-quantize-Q4_0: QUANTIZED_TYPE = Q4_0
causal-quantize-Q4_0: causal-quantize-model
# For Quantization Aware Trained (QAT) models in Q4_0 we explicitly set the
# token embedding and output types to Q8_0 instead of the default Q6_K.
causal-quantize-qat-Q4_0: QUANTIZED_TYPE = Q4_0
causal-quantize-qat-Q4_0: TOKEN_EMBD_TYPE = Q8_0
causal-quantize-qat-Q4_0: OUTPUT_TYPE = Q8_0
causal-quantize-qat-Q4_0: causal-quantize-model
causal-quantize-model:
@CONVERTED_MODEL="$(CONVERTED_MODEL)" QUANTIZED_TYPE="$(QUANTIZED_TYPE)" ./scripts/utils/quantize.sh ${CONVERTED_MODEL} ${QUANTIZED_TYPE}
@echo "Export the quantized model path to QUANTIZED_MODEL variable in your environment"
$(call quantize_model,$(CONVERTED_MODEL),QUANTIZED_MODEL)
causal-run-quantized-model:
@QUANTIZED_MODEL="$(QUANTIZED_MODEL)" ./scripts/causal/run-converted-model.sh ${QUANTIZED_MODEL}
@ -117,9 +145,15 @@ embedding-quantize-Q8_0: embedding-quantize-model
embedding-quantize-Q4_0: QUANTIZED_TYPE = Q4_0
embedding-quantize-Q4_0: embedding-quantize-model
# For Quantization Aware Trained (QAT) models in Q4_0 we explicitly set the
# token embedding and output types to Q8_0 instead of the default Q6_K.
embedding-quantize-qat-Q4_0: QUANTIZED_TYPE = Q4_0
embedding-quantize-qat-Q4_0: TOKEN_EMBD_TYPE = Q8_0
embedding-quantize-qat-Q4_0: OUTPUT_TYPE = Q8_0
embedding-quantize-qat-Q4_0: embedding-quantize-model
embedding-quantize-model:
@./scripts/utils/quantize.sh ${CONVERTED_EMBEDDING_MODEL} ${QUANTIZED_TYPE}
@echo "Export the quantized model path to QUANTIZED_EMBEDDING_MODEL variable in your environment"
$(call quantize_model,$(CONVERTED_EMBEDDING_MODEL),QUANTIZED_EMBEDDING_MODEL)
embedding-run-quantized-model:
@./scripts/embedding/run-converted-model.sh ${QUANTIZED_EMBEDDING_MODEL}
@ -144,6 +178,15 @@ perplexity-run:
hf-create-model:
@./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}"
hf-create-model-dry-run:
@./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -d
hf-create-model-embedding:
@./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -e
hf-create-model-embedding-dry-run:
@./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -e -d
hf-create-model-private:
@./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -p

View File

@ -137,6 +137,18 @@ Then the quantized model can be run using the following command:
(venv) $ make causal-run-quantized-model
```
### Quantizing QAT (Quantization Aware Training) models
When quantizing to `Q4_0`, the default data type for the token embedding weights
will be `Q6_K`. For models that are going to be uploaded to ggml-org it is
recommended to use `Q8_0` instead for the embeddings and output tensors.
The reason is that although `Q6_K` is smaller in size, it requires more compute
to unpack, which can hurt performance during output generation when the entire
embedding matrix must be dequantized to compute vocabulary logits. `Q8_0`
provides practically full quality with better computational efficiency.
```console
(venv) $ make causal-quantize-qat-Q4_0
```
## Embedding Language Model Conversion
@ -238,6 +250,18 @@ Then the quantized model can be run using the following command:
(venv) $ make embedding-run-quantized-model
```
### Quantizing QAT (Quantization Aware Training) models
When quantizing to `Q4_0`, the default data type for the token embedding weights
will be `Q6_K`. For models that are going to be uploaded to ggml-org it is
recommended to use `Q8_0` instead for the embeddings and output tensors.
The reason is that although `Q6_K` is smaller in size, it requires more compute
to unpack, which can hurt performance during output generation when the entire
embedding matrix must be dequantized to compute vocabulary logits. `Q8_0`
provides practically full quality with better computational efficiency.
```console
(venv) $ make embedding-quantize-qat-Q4_0
```
## Perplexity Evaluation
### Simple perplexity evaluation
@ -285,13 +309,21 @@ For the following targets a `HF_TOKEN` environment variable is required.
This will create a new model repsository on Hugging Face with the specified
model name.
```console
(venv) $ make hf-create-model MODEL_NAME='TestModel' NAMESPACE="danbev"
(venv) $ make hf-create-model MODEL_NAME='TestModel' NAMESPACE="danbev" ORIGINAL_BASE_MODEL="some-base-model"
Repository ID: danbev/TestModel-GGUF
Repository created: https://huggingface.co/danbev/TestModel-GGUF
```
Note that we append a `-GGUF` suffix to the model name to ensure a consistent
naming convention for GGUF models.
An embedding model can be created using the following command:
```console
(venv) $ make hf-create-model-embedding MODEL_NAME='TestEmbeddingModel' NAMESPACE="danbev" ORIGINAL_BASE_MODEL="some-base-model"
```
The only difference is that the model card for an embedding model will be different
with regards to the llama-server command and also how to access/call the embedding
endpoint.
### Upload a GGUF model to model repository
The following target uploads a model to an existing Hugging Face model repository.
```console

View File

@ -112,6 +112,7 @@ int main(int argc, char ** argv) {
ctx_params.no_perf = false;
if (embedding_mode) {
ctx_params.embeddings = true;
ctx_params.pooling_type = LLAMA_POOLING_TYPE_NONE;
ctx_params.n_ubatch = ctx_params.n_batch;
}

View File

@ -1,5 +1,21 @@
#!/bin/bash
set -e
# Parse command line arguments
MMPROJ=""
while [[ $# -gt 0 ]]; do
case $1 in
--mmproj)
MMPROJ="--mmproj"
shift
;;
*)
shift
;;
esac
done
MODEL_NAME="${MODEL_NAME:-$(basename "$MODEL_PATH")}"
OUTPUT_DIR="${OUTPUT_DIR:-../../models}"
TYPE="${OUTTYPE:-f16}"
@ -11,12 +27,20 @@ echo "Model name: ${MODEL_NAME}"
echo "Data type: ${TYPE}"
echo "Converted model path:: ${CONVERTED_MODEL}"
echo "Metadata override: ${METADATA_OVERRIDE}"
python ../../convert_hf_to_gguf.py --verbose \
${MODEL_PATH} \
--outfile ${CONVERTED_MODEL} \
--outtype ${TYPE} \
--metadata "${METADATA_OVERRIDE}"
CMD_ARGS=("python" "../../convert_hf_to_gguf.py" "--verbose")
CMD_ARGS+=("${MODEL_PATH}")
CMD_ARGS+=("--outfile" "${CONVERTED_MODEL}")
CMD_ARGS+=("--outtype" "${TYPE}")
[[ -n "$METADATA_OVERRIDE" ]] && CMD_ARGS+=("--metadata" "${METADATA_OVERRIDE}")
[[ -n "$MMPROJ" ]] && CMD_ARGS+=("${MMPROJ}")
"${CMD_ARGS[@]}"
echo ""
echo "The environment variable CONVERTED_MODEL can be set to this path using:"
echo "export CONVERTED_MODEL=$(realpath ${CONVERTED_MODEL})"
if [[ -n "$MMPROJ" ]]; then
mmproj_file="${OUTPUT_DIR}/mmproj-$(basename "${CONVERTED_MODEL}")"
echo "The mmproj model was created in $(realpath "$mmproj_file")"
fi

View File

@ -0,0 +1,48 @@
---
base_model:
- {base_model}
---
# {model_name} GGUF
Recommended way to run this model:
```sh
llama-server -hf {namespace}/{model_name}-GGUF
```
Then the endpoint can be accessed at http://localhost:8080/embedding, for
example using `curl`:
```console
curl --request POST \
--url http://localhost:8080/embedding \
--header "Content-Type: application/json" \
--data '{{"input": "Hello embeddings"}}' \
--silent
```
Alternatively, the `llama-embedding` command line tool can be used:
```sh
llama-embedding -hf {namespace}/{model_name}-GGUF --verbose-prompt -p "Hello embeddings"
```
#### embd_normalize
When a model uses pooling, or the pooling method is specified using `--pooling`,
the normalization can be controlled by the `embd_normalize` parameter.
The default value is `2` which means that the embeddings are normalized using
the Euclidean norm (L2). Other options are:
* -1 No normalization
* 0 Max absolute
* 1 Taxicab
* 2 Euclidean/L2
* \>2 P-Norm
This can be passed in the request body to `llama-server`, for example:
```sh
--data '{{"input": "Hello embeddings", "embd_normalize": -1}}' \
```
And for `llama-embedding`, by passing `--embd-normalize <value>`, for example:
```sh
llama-embedding -hf {namespace}/{model_name}-GGUF --embd-normalize -1 -p "Hello embeddings"
```

View File

@ -26,21 +26,31 @@ parser.add_argument('--namespace', '-ns', help='Namespace to add the model to',
parser.add_argument('--org-base-model', '-b', help='Original Base model name', default="")
parser.add_argument('--no-card', action='store_true', help='Skip creating model card')
parser.add_argument('--private', '-p', action='store_true', help='Create private model')
parser.add_argument('--embedding', '-e', action='store_true', help='Use embedding model card template')
parser.add_argument('--dry-run', '-d', action='store_true', help='Print repository info and template without creating repository')
args = parser.parse_args()
repo_id = f"{args.namespace}/{args.model_name}-GGUF"
print("Repository ID: ", repo_id)
repo_url = api.create_repo(
repo_id=repo_id,
repo_type="model",
private=args.private,
exist_ok=False
)
repo_url = None
if not args.dry_run:
repo_url = api.create_repo(
repo_id=repo_id,
repo_type="model",
private=args.private,
exist_ok=False
)
if not args.no_card:
template_path = "scripts/readme.md.template"
if args.embedding:
template_path = "scripts/embedding/modelcard.template"
else:
template_path = "scripts/causal/modelcard.template"
print("Template path: ", template_path)
model_card_content = load_template_and_substitute(
template_path,
model_name=args.model_name,
@ -48,16 +58,21 @@ if not args.no_card:
base_model=args.org_base_model,
)
if model_card_content:
api.upload_file(
path_or_fileobj=model_card_content.encode('utf-8'),
path_in_repo="README.md",
repo_id=repo_id
)
print("Model card created successfully.")
if args.dry_run:
print("\nTemplate Content:\n")
print(model_card_content)
else:
print("Failed to create model card.")
if model_card_content:
api.upload_file(
path_or_fileobj=model_card_content.encode('utf-8'),
path_in_repo="README.md",
repo_id=repo_id
)
print("Model card created successfully.")
else:
print("Failed to create model card.")
print(f"Repository created: {repo_url}")
if not args.dry_run and repo_url:
print(f"Repository created: {repo_url}")

View File

@ -4,6 +4,8 @@ set -e
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
QUANTIZED_TYPE="${2:-"$QUANTIZED_TYPE"}"
TOKEN_EMBD_TYPE="${3:-"${TOKEN_EMBD_TYPE}"}"
OUTPUT_TYPE="${4:-"${OUTPUT_TYPE}"}"
QUANTIZED_MODEL=$CONVERTED_MODEL
# Final check if we have a model path
@ -14,6 +16,11 @@ if [ -z "$CONVERTED_MODEL" ]; then
exit 1
fi
if [ -z "$QUANTIZED_TYPE" ]; then
echo "Error: QUANTIZED_TYPE is required" >&2
exit 1
fi
echo $CONVERTED_MODEL
# Process the quantized model filename
@ -26,9 +33,16 @@ else
exit 1
fi
cmake --build ../../build --target llama-quantize -j8
../../build/bin/llama-quantize $CONVERTED_MODEL $QUANTIZED_MODEL $QUANTIZED_TYPE
echo $TOKEN_EMBD_TYPE
echo $OUTPUT_TYPE
CMD_ARGS=("../../build/bin/llama-quantize")
[[ -n "$TOKEN_EMBD_TYPE" ]] && CMD_ARGS+=("--token-embedding-type" "$TOKEN_EMBD_TYPE")
[[ -n "$OUTPUT_TYPE" ]] && CMD_ARGS+=("--output-tensor-type" "$OUTPUT_TYPE")
CMD_ARGS+=("$CONVERTED_MODEL" "$QUANTIZED_MODEL" "$QUANTIZED_TYPE")
"${CMD_ARGS[@]}"
echo "Quantized model saved to: $QUANTIZED_MODEL"

View File

@ -1257,12 +1257,20 @@ static void aclnn_exp(ggml_backend_cann_context& ctx, aclTensor* acl_src) {
void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
aclTensor* acl_dst) {
GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst);
if(acl_dst == nullptr) {
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCos, acl_src);
} else {
GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst);
}
}
void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
aclTensor* acl_dst) {
GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst);
if(acl_dst == nullptr) {
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSin, acl_src);
} else {
GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst);
}
}
void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
@ -1419,17 +1427,17 @@ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx,
static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer,
float m, int64_t size, float start, float stop, float step){
int64_t ne[] = {size};
size_t nb[] = {sizeof(float)};
size_t nb[] = {sizeof(uint16_t)};
ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(float));
ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(uint16_t));
void* arange_buffer = arange_allocator.get();
aclTensor* arange_tensor = ggml_cann_create_tensor(
arange_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);
arange_buffer, ACL_FLOAT16, sizeof(uint16_t), ne, nb, 1);
aclnn_arange(ctx, arange_tensor, start, stop, step, size);
aclTensor* slope_tensor = ggml_cann_create_tensor(
slope_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);
slope_buffer, ACL_FLOAT16, sizeof(uint16_t), ne, nb, 1);
aclScalar* sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT);
@ -2221,13 +2229,54 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
ggml_cann_release_resources(ctx, acl_index, acl_value);
}
/**
* @brief Initializes and caches sine/cosine positional encoding values
* (used in RoPE, Rotary Position Embedding) for attention layers.
*
* This function computes and caches the sin/cos values of
* θ = position * theta_scale for RoPE encoding. The cache is shared
* across attention layers, and only the first attention layer will
* trigger initialization. The cache includes repeated sin/cos values
* with different repeat methods depending on the @param is_neox flag.
*
* Steps performed by this function:
* 1. Identify whether the target tensor belongs to Q/K in attention
* and restrict computation to the first layer only.
* 2. Initialize the theta scale array (arange power freq scaling).
* 3. Allocate sin/cos caches if the max prompt length increases.
* 4. Compute θ = position * theta_scale.
* 5. Compute sin(θ), cos(θ) and optionally scale by attn_factor.
* 6. Expand sin/cos values by repeat or repeat_interleave depending
* on whether @param is_neox is enabled.
* 7. Store the computed values into persistent buffers
* (ctx.rope_sin_ptr / ctx.rope_cos_ptr).
*
* @param ctx The CANN backend context, holding memory pool,
* stream, and persistent buffers for rope init/cache.
* @param dst The destination ggml_tensor whose computation
* depends on the cached RoPE values (usually Qcur/Kcur).
* @param theta_scale Scalar exponent base for computing theta scale values.
* @param freq_scale Frequency scaling factor, applied to theta scale.
* @param attn_factor Attention scaling factor, applied to sin/cos.
* @param is_neox Whether to use Neox-style repeat strategy
* (dim expansion vs repeat_interleave).
*/
static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
aclTensor* acl_cos_repeat_tensor,
aclTensor* acl_sin_repeat_tensor,
float theta_scale, float freq_scale,
float attn_factor, bool is_neox) {
// int sin/cos cache, cache has different repeat method depond on
// @param.is_neox
bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0);
bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0);
// used for accuracy testing
bool is_attention = is_q || is_k;
// just compute in first layer in attention
bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0);
if(is_attention && !is_fisrt_layer) {
return;
}
ggml_tensor* src0 = dst->src[0]; // input
ggml_tensor* src1 = dst->src[1]; // position
@ -2253,21 +2302,16 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
}
bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0);
bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0);
// used for accuracy testing
bool is_attention = is_q || is_k;
if(ctx.init_ptr == nullptr || !is_attention) {
// init theta scale, just one time
if(ctx.rope_init_ptr == nullptr || !is_attention) {
// theta_scale arange, [0,1,...,ne00/2 - 1]
if(ctx.init_ptr != nullptr){
ACL_CHECK(aclrtFree(ctx.init_ptr));
if(ctx.rope_init_ptr != nullptr){
ACL_CHECK(aclrtFree(ctx.rope_init_ptr));
}
ACL_CHECK(aclrtMalloc(&ctx.init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
ACL_CHECK(aclrtMalloc(&ctx.rope_init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
aclTensor* acl_theta_scale_tensor =
ggml_cann_create_tensor(ctx.init_ptr, ACL_FLOAT, sizeof(float_t),
ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
float start = 0;
float step = 1;
@ -2297,67 +2341,55 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
ggml_cann_release_resources(ctx, acl_theta_scale_tensor,acl_theta_scale);
}
if(ctx.sin_ptr == nullptr) {
int64_t theta_length = theta_scale_length * ctx.max_prompt_length;
ACL_CHECK(aclrtMalloc(&ctx.sin_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
ACL_CHECK(aclrtMalloc(&ctx.cos_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
}
// init sin_repeat && cos_repeat, one token just init in 0 layer
if(position_length > ctx.max_prompt_length) {
ctx.max_prompt_length = position_length;
int64_t theta_length = theta_scale_length * ctx.max_prompt_length;
ACL_CHECK(aclrtFree(ctx.sin_ptr));
ACL_CHECK(aclrtFree(ctx.cos_ptr));
ACL_CHECK(aclrtMalloc(&ctx.sin_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
ACL_CHECK(aclrtMalloc(&ctx.cos_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
int64_t repeat_theta_length = theta_scale_length * ctx.max_prompt_length * 2;
if(ctx.rope_sin_ptr != nullptr) {
ACL_CHECK(aclrtFree(ctx.rope_sin_ptr));
ACL_CHECK(aclrtFree(ctx.rope_cos_ptr));
}
ACL_CHECK(aclrtMalloc(&ctx.rope_sin_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
ACL_CHECK(aclrtMalloc(&ctx.rope_cos_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
}
bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0);
if(is_fisrt_layer || !is_attention) {
aclTensor* acl_theta_scale_tensor =
ggml_cann_create_tensor(ctx.init_ptr, ACL_FLOAT, sizeof(float_t),
aclTensor* acl_theta_scale_tensor =
ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
// position
aclTensor* acl_position_tensor = ggml_cann_create_tensor(
src1->data, ggml_cann_type_mapping(src1->type),
ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS);
// position
aclTensor* acl_position_tensor = ggml_cann_create_tensor(
src1->data, ggml_cann_type_mapping(src1->type),
ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS);
// power * position
int64_t theta_length = theta_scale_length * position_length;
ggml_cann_pool_alloc theta_allocator(ctx.pool(),
theta_length * sizeof(float_t));
void* theta_buffer = theta_allocator.get();
// power * position
int64_t theta_length = theta_scale_length * position_length;
ggml_cann_pool_alloc theta_allocator(ctx.pool(),
theta_length * sizeof(float_t));
void* theta_buffer = theta_allocator.get();
aclTensor* acl_theta_tensor =
ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t),
theta_ne, theta_nb, GGML_MAX_DIMS);
aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
acl_theta_tensor);
// sin/cos
aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
ctx.sin_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
GGML_MAX_DIMS, ACL_FORMAT_ND);
aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor);
aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
ctx.cos_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
GGML_MAX_DIMS, ACL_FORMAT_ND);
aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);
// release
ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
acl_theta_tensor, acl_sin_tensor, acl_cos_tensor);
}
aclTensor* acl_theta_tensor =
ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t),
theta_ne, theta_nb, GGML_MAX_DIMS);
aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
acl_theta_tensor);
// sin/cos
ggml_cann_pool_alloc sin_allocator(ctx.pool(),
theta_length * sizeof(float_t));
void* sin_buffer = sin_allocator.get();
aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
ctx.sin_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
GGML_MAX_DIMS, ACL_FORMAT_ND);
sin_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
GGML_MAX_DIMS, ACL_FORMAT_ND);
aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor);
ggml_cann_pool_alloc cos_allocator(ctx.pool(),
theta_length * sizeof(float_t));
void* cos_buffer = cos_allocator.get();
aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
ctx.cos_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
GGML_MAX_DIMS, ACL_FORMAT_ND);
cos_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
GGML_MAX_DIMS, ACL_FORMAT_ND);
aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);
// attn_factor
if (attn_factor != 1) {
@ -2365,6 +2397,19 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
aclnn_muls(ctx, acl_cos_tensor, attn_factor, nullptr, true);
}
int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
size_t sin_reshape_nb[GGML_MAX_DIMS];
sin_reshape_nb[0] = sizeof(float_t);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
}
aclTensor* acl_sin_repeat_tensor =
ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
aclTensor* acl_cos_repeat_tensor =
ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
// repeat
if (is_neox) {
int64_t repeatsArray[] = {1, 1, 1, 2};
@ -2380,8 +2425,9 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
num_repeats, output_size);
}
// release
ggml_cann_release_resources(ctx, acl_sin_tensor, acl_cos_tensor);
ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
acl_theta_tensor, acl_sin_tensor, acl_sin_repeat_tensor, acl_cos_tensor,
acl_cos_repeat_tensor);
}
#ifdef __cplusplus
@ -2435,13 +2481,8 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
// init cos/sin cache
ggml_cann_pool_alloc sin_allocator(
ctx.pool(), ne00 * ne02 * sizeof(float_t));
ggml_cann_pool_alloc cos_allocator(
ctx.pool(), ne00 * ne02 * sizeof(float_t));
void* sin_buffer = sin_allocator.get();
void* cos_buffer = cos_allocator.get();
// init ctx.rope_cos/rope_sin cache
aclnn_cache_init(ctx, dst, theta_scale, freq_scale, attn_factor, is_neox);
int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
size_t sin_reshape_nb[GGML_MAX_DIMS];
@ -2450,13 +2491,11 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
}
aclTensor* acl_sin_reshape_tensor =
ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float_t),
ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
aclTensor* acl_cos_reshape_tensor =
ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float_t),
ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
aclnn_cache_init(ctx, dst, acl_cos_reshape_tensor, acl_sin_reshape_tensor,
theta_scale, freq_scale, attn_factor, is_neox);
aclTensor* acl_src = ggml_cann_create_tensor(src0);
aclTensor* acl_dst = ggml_cann_create_tensor(dst);
@ -3141,11 +3180,38 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
ggml_tensor* src0 = dst->src[0]; // q, fp32
ggml_tensor* src1 = dst->src[1]; // k, fp16
ggml_tensor* src2 = dst->src[2]; // v, fp16
ggml_tensor* src0 = dst->src[0]; // q, fp32 | B, N, S, D (uncont) -> B, S, N, D (cont)
ggml_tensor* src1 = dst->src[1]; // k, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont)
ggml_tensor* src2 = dst->src[2]; // v, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont)
ggml_tensor* src3 = dst->src[3]; // mask, fp16
// B, N, S, D (uncont) -> B, S, N, D (cont)
int64_t src0_bsnd_ne[GGML_MAX_DIMS];
memcpy(src0_bsnd_ne, src0->ne, GGML_MAX_DIMS * sizeof(int64_t));
size_t src0_bsnd_nb[GGML_MAX_DIMS];
memcpy(src0_bsnd_nb, src0->nb, GGML_MAX_DIMS * sizeof(size_t));
int64_t src1_bsnd_ne[GGML_MAX_DIMS];
memcpy(src1_bsnd_ne, src1->ne, GGML_MAX_DIMS * sizeof(int64_t));
size_t src1_bsnd_nb[GGML_MAX_DIMS];
memcpy(src1_bsnd_nb, src1->nb, GGML_MAX_DIMS * sizeof(size_t));
int64_t src2_bsnd_ne[GGML_MAX_DIMS];
memcpy(src2_bsnd_ne, src2->ne, GGML_MAX_DIMS * sizeof(int64_t));
size_t src2_bsnd_nb[GGML_MAX_DIMS];
memcpy(src2_bsnd_nb, src2->nb, GGML_MAX_DIMS * sizeof(size_t));
auto transpose12 = [](int64_t* ne, size_t* nb) {
int64_t ne_tmp = ne[1];
size_t nb_tmp = nb[1];
ne[1] = ne[2];
nb[1] = nb[2];
ne[2] = ne_tmp;
nb[2] = nb_tmp;
};
transpose12(src0_bsnd_ne, src0_bsnd_nb);
transpose12(src1_bsnd_ne, src1_bsnd_nb);
transpose12(src2_bsnd_ne, src2_bsnd_nb);
float maxBias = 0.0f;
float scaleValue = 1.0f;
float logitSoftcap = 0.0f;
@ -3167,11 +3233,12 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
void* src0_f16_buffer = nullptr;
if(ggml_cann_type_mapping(src0->type) != faDataType){
aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0);
aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne,
src0_bsnd_nb, GGML_MAX_DIMS);
src0_f16_buffer = src0_f16_allocator.alloc(
ggml_nelements(src0) * faElemSize);
int64_t* src0_f16_ne = src0->ne;
int64_t* src0_f16_ne = src0_bsnd_ne;
size_t src0_f16_nb[GGML_MAX_DIMS];
src0_f16_nb[0] = sizeof(uint16_t);
for(int i = 1; i < GGML_MAX_DIMS; ++i){
@ -3185,20 +3252,23 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType);
ggml_cann_release_resources(ctx, acl_src0_f32_tensor);
}else{
acl_src0_f16_tensor = ggml_cann_create_tensor(src0);
acl_src0_f16_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne,
src0_bsnd_nb, GGML_MAX_DIMS);
}
// Step 2: create the acl tensors for src1 (Key), src2 (Value),
// and the direct output from FusedInferAttention
acl_src1_f16_tensor = ggml_cann_create_tensor(src1);
acl_src2_f16_tensor = ggml_cann_create_tensor(src2);
acl_src1_f16_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne,
src1_bsnd_nb, GGML_MAX_DIMS);
acl_src2_f16_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne,
src2_bsnd_nb, GGML_MAX_DIMS);
ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
void* out_f16_buffer = out_f16_allocator.alloc(
ggml_nelements(dst) * faElemSize);
int64_t* out_f16_ne = src0->ne;
int64_t* out_f16_ne = src0_bsnd_ne;
size_t out_f16_nb[GGML_MAX_DIMS];
out_f16_nb[0] = faElemSize;
for(int i = 1; i < GGML_MAX_DIMS; ++i){
@ -3212,88 +3282,81 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
// Step 3: create the PSEShift tensor if needed
// this tensor is considered as mask (f16) in the llama.cpp
aclTensor* bcast_pse_tensor = nullptr;
int64_t bcast_pse_ne[GGML_MAX_DIMS];
size_t bcast_pse_nb[GGML_MAX_DIMS];
ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool());
void* bcast_pse_buffer = nullptr;
if(src3 != nullptr){
bcast_pse_buffer = bcast_pse_allocator.alloc(
ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t));
// Construct the truncated pse tensor (common for prefill/decode)
int64_t trunc_pse_ne[GGML_MAX_DIMS] = {
src3->ne[0], // D
src0->ne[1], // S (number of Q tokens)
src3->ne[2], // mask N
src3->ne[3] // B
};
size_t* trunc_pse_nb = src3->nb;
if(src0->ne[1] > 1){
// Case 1: broadcast pse for prefill stage with multiple head
aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3);
bcast_pse_ne[0] = src3->ne[0];
bcast_pse_ne[1] = src3->ne[1];
bcast_pse_ne[2] = src0->ne[2];
bcast_pse_ne[3] = src3->ne[3];
aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
src3->data, ACL_FLOAT16, sizeof(uint16_t),
trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS
);
int64_t bcast_pse_ne[GGML_MAX_DIMS];
size_t bcast_pse_nb[GGML_MAX_DIMS];
bcast_pse_ne[0] = src3->ne[0]; // D
bcast_pse_ne[1] = src0->ne[1]; // S
bcast_pse_ne[2] = src0->ne[2]; // N (num_heads)
bcast_pse_ne[3] = src3->ne[3]; // B
if (maxBias == 0.0f) {
// When maxBias == 0.0f, use nb = 0 reduce once repeat (Qwen2)
// Construct the bcast tensor (simulate repeat on the head dimension using stride=0)
bcast_pse_nb[0] = sizeof(uint16_t);
for(int i = 1; i < GGML_MAX_DIMS; ++i){
bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
}
bcast_pse_nb[1] = bcast_pse_nb[0] * bcast_pse_ne[0];
bcast_pse_nb[2] = 0; // <---- the head dimension shares the same data
bcast_pse_nb[3] = src3->nb[3];
bcast_pse_tensor = ggml_cann_create_tensor(
bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
int64_t repeats[] = {1, src0->ne[2], 1, 1};
aclnn_repeat(ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats);
ggml_cann_release_resources(ctx, acl_mask_f16_tensor);
}else{
// Case 2: trunc the first row and broadcast pse for decode stage with multiple head
int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]};
size_t* trunc_pse_nb = src3->nb;
aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
src3->data, ACL_FLOAT16, sizeof(uint16_t),
trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS);
bcast_pse_ne[0] = src3->ne[0];
bcast_pse_ne[1] = src0->ne[1];
bcast_pse_ne[2] = src0->ne[2];
bcast_pse_ne[3] = src3->ne[3];
bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS
);
ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
} else {
bcast_pse_nb[0] = sizeof(uint16_t);
for(int i = 1; i < GGML_MAX_DIMS; ++i){
for (int i = 1; i < GGML_MAX_DIMS; i++) {
bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
}
void* bcast_pse_buffer = bcast_pse_allocator.alloc(
ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t)
);
bcast_pse_tensor = ggml_cann_create_tensor(
bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS
);
int64_t repeats[] = {1, src0->ne[2], 1, 1};
aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats);
ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
}
// Compute the slope if needed. Derived from ggml_cann_softmax().
if(maxBias != 0.0f){
// alibi
// Compute the slope if needed. Derived from ggml_cann_softmax().
const int64_t n_heads = src0->ne[2];
ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float));
ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(uint16_t));
void* slope_buffer = slope_allocator.get();
aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias);
int64_t slope_ne[] = {1, 1, n_heads, 1};
size_t slope_nb[GGML_MAX_DIMS];
slope_nb[0] = sizeof(float);
slope_nb[0] = sizeof(uint16_t);
for(int i = 1;i<GGML_MAX_DIMS;i++) {
slope_nb[i] = slope_nb[i-1] * slope_ne[0];
}
aclTensor* slope_tensor = ggml_cann_create_tensor(
slope_buffer, ACL_FLOAT, sizeof(float),
slope_buffer, ACL_FLOAT16, sizeof(uint16_t),
slope_ne, slope_nb, GGML_MAX_DIMS);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, slope_tensor);
ggml_cann_release_resources(ctx, slope_tensor);
ggml_cann_release_resources(ctx, slope_tensor, acl_mask_f16_trunc_tensor);
}
}
@ -3310,7 +3373,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
// double scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d)
int64_t preTokens = 65535;
int64_t nextTokens = 65535;
char layout[5] = {'B', 'N', 'S', 'D', 0};
char layout[5] = {'B', 'S', 'N', 'D', 0};
int64_t sparseMode = 0;
int64_t innerPrecise = (src0->ne[1] == 1) ? 0 : 2;
int64_t blockSize = 0;
@ -3347,32 +3410,9 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
);
// Step 6: post-processing, permute and cast to f32
int64_t new_dim[] = {0, 2, 1, 3};
aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
if(ggml_cann_type_mapping(dst->type) != faDataType){
ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool());
perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
void* perm_out_f16_buffer = perm_out_f16_allocator.get();
int64_t* perm_out_f16_ne = dst->ne;
size_t perm_out_f16_nb[GGML_MAX_DIMS];
perm_out_f16_nb[0] = faElemSize;
for(int i = 1; i < GGML_MAX_DIMS; ++i){
perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1];
}
aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor(
perm_out_f16_buffer, faDataType, faElemSize,
perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS);
aclnn_permute(ctx, acl_dst_f16_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS);
aclnn_cast(ctx,
acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor);
}else{
// only need to permute
aclnn_permute(ctx, acl_dst_f16_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS);
}
// TODO: when dst is fp16, don't need cast
aclnn_cast(ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
acl_src1_f16_tensor,
acl_src2_f16_tensor,

View File

@ -368,17 +368,18 @@ struct ggml_backend_cann_context {
std::string name; /**< Name of the device. */
std::string description; /**< Description of the device. */
aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
void* init_ptr = nullptr;
void* sin_ptr = nullptr;
void* cos_ptr = nullptr;
int64_t max_prompt_length = 65536;
#ifdef USE_ACL_GRAPH
/// Cached CANN ACL graph used for executing the current ggml computation graph.
std::unique_ptr<ggml_cann_graph> cann_graph;
#endif
cann_task_queue task_queue;
bool async_mode;
bool support_set_rows;
// Rope Cache
void* rope_init_ptr = nullptr;
void* rope_sin_ptr = nullptr;
void* rope_cos_ptr = nullptr;
int64_t max_prompt_length = 0;
// Constant Pool
void* f32_zero_cache = nullptr;
void* f32_one_cache = nullptr;
int64_t f32_zero_cache_element = 0;
@ -398,14 +399,6 @@ struct ggml_backend_cann_context {
async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__,
device, async_mode ? "ON" : "OFF");
support_set_rows = parse_bool(get_env("LLAMA_SET_ROWS").value_or(""));
GGML_LOG_INFO("%s: LLAMA_SET_ROWS is %s\n", __func__, support_set_rows ? "ON" : "OFF");
if (!support_set_rows) {
GGML_LOG_INFO("%s: CANN Graph currently only supports execution when LLAMA_SET_ROWS is ON. "
"Falling back to eager mode.\n", __func__);
}
}
/**
@ -422,14 +415,20 @@ struct ggml_backend_cann_context {
ACL_CHECK(aclrtDestroyStream(streams[i]));
}
}
if(init_ptr != nullptr) {
ACL_CHECK(aclrtFree(init_ptr));
if(rope_init_ptr != nullptr) {
ACL_CHECK(aclrtFree(rope_init_ptr));
}
if(sin_ptr != nullptr) {
ACL_CHECK(aclrtFree(sin_ptr));
if(rope_sin_ptr != nullptr) {
ACL_CHECK(aclrtFree(rope_sin_ptr));
}
if(cos_ptr != nullptr) {
ACL_CHECK(aclrtFree(cos_ptr));
if(rope_cos_ptr != nullptr) {
ACL_CHECK(aclrtFree(rope_cos_ptr));
}
if(f32_zero_cache != nullptr) {
ACL_CHECK(aclrtFree(f32_zero_cache));
}
if(f32_one_cache != nullptr) {
ACL_CHECK(aclrtFree(f32_one_cache));
}
}

View File

@ -1155,7 +1155,7 @@ namespace {
* @note The workspace buffer used in this function is managed globally and reused
* across calls. This reduces overhead from repeated memory allocation and deallocation.
*/
static void weight_format_to_nz(ggml_tensor *tensor, const void *data, size_t offset) {
static void weight_format_to_nz(ggml_tensor *tensor, size_t offset) {
aclTensor* weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne,
tensor->nb, 2, ACL_FORMAT_ND, offset);
uint64_t workspaceSize = 0;
@ -1203,7 +1203,7 @@ static void ggml_backend_cann_buffer_set_tensor(
if (weight_to_nz && is_matmul_weight((const ggml_tensor*)tensor)) {
GGML_ASSERT(tensor->ne[2] == 1);
GGML_ASSERT(tensor->ne[3] == 1);
weight_format_to_nz(tensor, data, offset);
weight_format_to_nz(tensor, offset);
}
} else {
void *transform_buffer = malloc(size);
@ -2251,11 +2251,6 @@ static enum ggml_status ggml_backend_cann_graph_compute(
bool use_cann_graph = true;
bool cann_graph_update_required = false;
// check environment LLAMA_SET_ROWS
if (!cann_ctx->support_set_rows) {
use_cann_graph = false;
}
if (use_cann_graph) {
if (cann_ctx->cann_graph == nullptr) {
cann_ctx->cann_graph.reset(new ggml_cann_graph());
@ -2336,7 +2331,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_0:
#ifdef ASCEND_310P
// Q4 && Q8 per group is not suppor on 310p device
// Q4 && Q8 per group is not support on 310p device
return false;
#endif
// only support contiguous for quantized types.
@ -2354,7 +2349,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_0:
#ifdef ASCEND_310P
// Q4 && Q8 per group is not suppor on 310p device
// Q4 && Q8 per group is not support on 310p device
return false;
#endif
// only support contiguous for quantized types.
@ -2496,7 +2491,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
return true;
case GGML_OP_SCALE:
float bias;
memcpy(&bias, (float*)op->op_params + 1, sizeof(float));
memcpy(&bias, (const float *)(op->op_params) + 1, sizeof(float));
return bias == 0.0f; // TODO: support bias != 0.0f
case GGML_OP_SOFT_MAX:
// TODO: support attention sinks [TAG_ATTN_SINKS]
@ -2505,6 +2500,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
}
return true;
case GGML_OP_FLASH_ATTN_EXT:{
#ifdef ASCEND_310P
// FA not support on 310p device
return false;
#endif
// derived from [ggml-cuda.cu]
if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
return false;
@ -2530,8 +2529,12 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
// DeepSeek MLA
return false;
}
if (op->src[0]->ne[0] % 16 != 0) {
// TODO: padding to support
return false;
}
float logitSoftcap = 0.0f;
memcpy(&logitSoftcap, (float*)op->op_params + 2, sizeof(float));
memcpy(&logitSoftcap, (const float *)(op->op_params) + 2, sizeof(float));
if(logitSoftcap != 0.0f) {
return false;
}

View File

@ -435,7 +435,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
)
if (GGML_RVV)
if (GGML_XTHEADVECTOR)
list(APPEND ARCH_FLAGS -march=rv64gc_xtheadvector -mabi=lp64d)
list(APPEND ARCH_FLAGS -march=rv64gc_zfhmin_xtheadvector -mabi=lp64d)
elseif (GGML_RV_ZFH)
list(APPEND ARCH_FLAGS -march=rv64gcv_zfhmin -mabi=lp64d)
else()

View File

@ -489,7 +489,7 @@ inline static int16x8_t vec_padd_s16(int16x8_t a, int16x8_t b) {
/**
* @see https://github.com/ggml-org/llama.cpp/pull/14037
*/
inline float vec_hsum(float32x4_t v) {
inline static float vec_hsum(float32x4_t v) {
float32x4_t v_temp = v + vec_reve(v);
return v_temp[0] + v_temp[1];
}

View File

@ -2169,94 +2169,117 @@ class tinyBLAS_Q0_PPC {
class tinyBLAS_PPC {
public:
tinyBLAS_PPC(int64_t k,
const float *A, int64_t lda,
const float *B, int64_t ldb,
float *C, int64_t ldc,
const float * A, int64_t lda,
const float * B, int64_t ldb,
float * C, int64_t ldc,
int ith, int nth)
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
}
void matmul(int64_t m, int64_t n) {
mnpack(0, m, 0, n);
int64_t mc = 256; int64_t nc = 256; int64_t kc = 256;
if (m % mc == 0 && n % nc == 0 && k % kc == 0) {
matmul_tiled(m, n, mc, nc, kc);
} else {
mnpack(0, m, 0, n);
}
}
private:
void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
inline void vector_permute_store_4(vector float *src, float *vecOffset) {
vector float t1, t2, t3, t4, t5, t6, t7, t8;
t1 = vec_mergeh(src[0], src[1]);
t2 = vec_mergeh(src[2], src[3]);
t3 = vec_mergel(src[0], src[1]);
t4 = vec_mergel(src[2], src[3]);
t5 = vec_xxpermdi(t1, t2, 0);
t6 = vec_xxpermdi(t1, t2, 3);
t7 = vec_xxpermdi(t3, t4, 0);
t8 = vec_xxpermdi(t3, t4, 3);
vec_xst(t5, 0, vecOffset);
vec_xst(t6, 0, vecOffset + 4);
vec_xst(t7, 0, vecOffset + 8);
vec_xst(t8, 0, vecOffset + 12);
}
inline void vector_permute_store_8(vector float *src, float *vecOffset) {
vector float t1, t2, t3, t4, t5, t6, t7, t8;
t1 = vec_mergeh(src[0], src[1]);
t2 = vec_mergeh(src[2], src[3]);
t3 = vec_mergeh(src[4], src[5]);
t4 = vec_mergeh(src[6], src[7]);
t5 = vec_xxpermdi(t1, t2, 0);
t6 = vec_xxpermdi(t3, t4, 0);
t7 = vec_xxpermdi(t1, t2, 3);
t8 = vec_xxpermdi(t3, t4, 3);
vec_xst(t5, 0, vecOffset);
vec_xst(t6, 0, vecOffset + 4);
vec_xst(t7, 0, vecOffset + 8);
vec_xst(t8, 0, vecOffset + 12);
t1 = vec_mergel(src[0], src[1]);
t2 = vec_mergel(src[2], src[3]);
t3 = vec_mergel(src[4], src[5]);
t4 = vec_mergel(src[6], src[7]);
t5 = vec_xxpermdi(t1, t2, 0);
t6 = vec_xxpermdi(t3, t4, 0);
t7 = vec_xxpermdi(t1, t2, 3);
t8 = vec_xxpermdi(t3, t4, 3);
vec_xst(t5, 0, vecOffset + 16);
vec_xst(t6, 0, vecOffset + 20);
vec_xst(t7, 0, vecOffset + 24);
vec_xst(t8, 0, vecOffset + 28);
inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
vec_t vec_C[4];
__builtin_mma_disassemble_acc(vec_C, ACC);
for (int I = 0; I < 4; I++) {
for (int J = 0; J < 4; J++) {
*((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
}
}
}
void packTranspose(const float* a, int64_t lda, int rows, int cols, float* vec) {
inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
vec_t vec_C[4];
__builtin_mma_disassemble_acc(vec_C, ACC);
for (int I = 0; I < 4; I++) {
for (int J = 0; J < 4; J++) {
float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
*c_ptr += *((float *)&vec_C[I]+J);
}
}
}
inline void vector_permute_store_4(vector float * src, float * vecOffset) {
vector float t1, t2, t3, t4, t5, t6, t7, t8;
t1 = vec_mergeh(src[0], src[1]);
t2 = vec_mergeh(src[2], src[3]);
t3 = vec_mergel(src[0], src[1]);
t4 = vec_mergel(src[2], src[3]);
t5 = vec_xxpermdi(t1, t2, 0);
t6 = vec_xxpermdi(t1, t2, 3);
t7 = vec_xxpermdi(t3, t4, 0);
t8 = vec_xxpermdi(t3, t4, 3);
vec_xst(t5, 0, vecOffset);
vec_xst(t6, 0, vecOffset + 4);
vec_xst(t7, 0, vecOffset + 8);
vec_xst(t8, 0, vecOffset + 12);
}
inline void vector_permute_store_8(vector float * src, float * vecOffset) {
vector float t1, t2, t3, t4, t5, t6, t7, t8;
t1 = vec_mergeh(src[0], src[1]);
t2 = vec_mergeh(src[2], src[3]);
t3 = vec_mergeh(src[4], src[5]);
t4 = vec_mergeh(src[6], src[7]);
t5 = vec_xxpermdi(t1, t2, 0);
t6 = vec_xxpermdi(t3, t4, 0);
t7 = vec_xxpermdi(t1, t2, 3);
t8 = vec_xxpermdi(t3, t4, 3);
vec_xst(t5, 0, vecOffset);
vec_xst(t6, 0, vecOffset + 4);
vec_xst(t7, 0, vecOffset + 8);
vec_xst(t8, 0, vecOffset + 12);
t1 = vec_mergel(src[0], src[1]);
t2 = vec_mergel(src[2], src[3]);
t3 = vec_mergel(src[4], src[5]);
t4 = vec_mergel(src[6], src[7]);
t5 = vec_xxpermdi(t1, t2, 0);
t6 = vec_xxpermdi(t3, t4, 0);
t7 = vec_xxpermdi(t1, t2, 3);
t8 = vec_xxpermdi(t3, t4, 3);
vec_xst(t5, 0, vecOffset + 16);
vec_xst(t6, 0, vecOffset + 20);
vec_xst(t7, 0, vecOffset + 24);
vec_xst(t8, 0, vecOffset + 28);
}
void packTranspose(const float * a, int64_t lda, int rows, int cols, float * vec) {
int64_t i, j;
float * aoffsets[8];
float *aoffset = NULL, *boffset = NULL;
float * aoffset = NULL, * boffset = NULL;
__vector_pair arr[8];
vector float c[8][2] = {0};
vector float c1[8] = {0};
vector float c2[8] = {0};
aoffset = const_cast<float*>(a);
aoffset = const_cast<float *>(a);
boffset = vec;
j = (rows >> 3);
if (j > 0) {
do {
aoffsets[0] = aoffset;
for (int it = 1; it< 8; it++)
for (int it = 1; it < 8; it++)
aoffsets[it] = aoffsets[it-1] + lda;
aoffset += 8 * lda;
i = (cols >> 3);
if (i > 0) {
do {
for (int it = 0; it< 8; it++) {
for (int it = 0; it < 8; it++) {
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
__builtin_vsx_disassemble_pair(c[it], &arr[it]);
c1[it] = c[it][0];
@ -2264,11 +2287,14 @@ class tinyBLAS_PPC {
}
vector_permute_store_8(c1, boffset);
vector_permute_store_8(c2, boffset+32);
for (int it = 0; it < 4; it++)
aoffsets[it] = aoffsets[it] + 8*lda;
vector_permute_store_8(c2, boffset + 32);
boffset += 64;
i--;
if (i > 0) {
for (int it = 0; it < 8; it++) {
aoffsets[it] = aoffsets[it] + 8;
}
}
} while(i > 0);
}
if (cols & 4) {
@ -2295,9 +2321,9 @@ class tinyBLAS_PPC {
c2[it] = c[it][1];
}
vector_permute_store_4(c1, boffset);
vector_permute_store_4(c2, boffset+16);
vector_permute_store_4(c2, boffset + 16);
for (int it = 0; it < 4; it++)
aoffsets[it] += 8*lda;
aoffsets[it] += 8 * lda;
boffset += 32;
i--;
} while(i > 0);
@ -2325,15 +2351,15 @@ class tinyBLAS_PPC {
vec_t vec_A[4], vec_B[4], vec_C[4];
acc_t acc_0;
__builtin_mma_xxsetaccz(&acc_0);
for (int l = 0; l < k; l+=4) {
packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
for (int l = 0; l < k; l += 4) {
packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
}
SAVE_ACC(&acc_0, ii, jj);
save_acc(&acc_0, ii, jj);
}
void KERNEL_4x8(int64_t ii, int64_t jj) {
@ -2341,9 +2367,9 @@ class tinyBLAS_PPC {
acc_t acc_0, acc_1;
__builtin_mma_xxsetaccz(&acc_0);
__builtin_mma_xxsetaccz(&acc_1);
for (int64_t l = 0; l < k; l+=4) {
packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
packTranspose(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
for (int64_t l = 0; l < k; l += 4) {
packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
packTranspose(B + (jj * ldb) + l, ldb, 8, 4, (float *)vec_B);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
@ -2353,8 +2379,8 @@ class tinyBLAS_PPC {
__builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);
__builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);
}
SAVE_ACC(&acc_0, ii, jj);
SAVE_ACC(&acc_1, ii, jj+4);
save_acc(&acc_0, ii, jj);
save_acc(&acc_1, ii, jj + 4);
}
void KERNEL_8x4(int64_t ii, int64_t jj) {
@ -2362,9 +2388,9 @@ class tinyBLAS_PPC {
acc_t acc_0, acc_1;
__builtin_mma_xxsetaccz(&acc_0);
__builtin_mma_xxsetaccz(&acc_1);
for (int64_t l = 0; l < k; l+=4) {
packTranspose(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
for (int64_t l = 0; l < k; l += 4) {
packTranspose(A + (ii * lda) + l, lda, 8, 4, (float *)vec_A);
packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
@ -2374,8 +2400,8 @@ class tinyBLAS_PPC {
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);
}
SAVE_ACC(&acc_0, ii, jj);
SAVE_ACC(&acc_1, ii+4, jj);
save_acc(&acc_0, ii, jj);
save_acc(&acc_1, ii + 4, jj);
}
void KERNEL_8x8(int64_t ii, int64_t jj) {
@ -2386,19 +2412,96 @@ class tinyBLAS_PPC {
__builtin_mma_xxsetaccz(&acc_2);
__builtin_mma_xxsetaccz(&acc_3);
for (int l = 0; l < k; l+=8) {
packTranspose(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
packTranspose(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
packTranspose(A + (ii * lda) + l, lda, 8, 8, (float *)vec_A);
packTranspose(B + (jj * ldb) + l, ldb, 8, 8, (float *)vec_B);
for(int x = 0; x < 16; x+=2) {
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
__builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x+1], vec_B[x]);
__builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x+1], vec_B[x+1]);
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x + 1]);
__builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x + 1], vec_B[x]);
__builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x + 1], vec_B[x + 1]);
}
}
save_acc(&acc_0, ii, jj);
save_acc(&acc_1, ii, jj + 4);
save_acc(&acc_2, ii + 4, jj);
save_acc(&acc_3, ii + 4, jj + 4);
}
inline void MMA_16x8(vec_t * vec_A0, vec_t * vec_A1, vec_t * vec_B, acc_t * acc) {
for (int x = 0; x < 16; x += 2) {
__builtin_mma_xvf32gerpp(&acc[0], vec_A0[x + 0], vec_B[x]);
__builtin_mma_xvf32gerpp(&acc[1], vec_A0[x + 0], vec_B[x + 1]);
__builtin_mma_xvf32gerpp(&acc[2], vec_A0[x + 1], vec_B[x]);
__builtin_mma_xvf32gerpp(&acc[3], vec_A0[x + 1], vec_B[x + 1]);
__builtin_mma_xvf32gerpp(&acc[4], vec_A1[x + 0], vec_B[x]);
__builtin_mma_xvf32gerpp(&acc[5], vec_A1[x + 0], vec_B[x + 1]);
__builtin_mma_xvf32gerpp(&acc[6], vec_A1[x + 1], vec_B[x]);
__builtin_mma_xvf32gerpp(&acc[7], vec_A1[x + 1], vec_B[x + 1]);
}
}
void KERNEL(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, vec_t * vec_A, vec_t * vec_B, int64_t kk) {
for (int64_t i = 0; i < mc; i += 16) {
int A_base_addr = (mc / 8) * (i / 8) * 16;
for (int64_t j = 0; j < nc; j += 8) {
int B_base_addr = (nc / 8) * (j / 8) * 16;
acc_t acc[8];
vec_t A0_block[16]; vec_t A1_block[16];
for (int x = 0; x < 8; x++)
__builtin_mma_xxsetaccz(&acc[x]);
for (int64_t l = 0; l < kc; l += 8) {
int A0_block_idx = A_base_addr + (l / 8) * 16;
int A1_block_idx = A0_block_idx + (mc / 8) * 16;
int B_block_idx = B_base_addr + (l / 8) * 16;
vec_t* A0_block = &vec_A[A0_block_idx];
vec_t* A1_block = &vec_A[A1_block_idx];
vec_t* B_block = &vec_B[B_block_idx];
MMA_16x8(A0_block, A1_block, B_block, acc);
}
if (kk == 0) {
save_acc(&acc[0], ii + i, jj + j);
save_acc(&acc[1], ii + i, jj + j + 4);
save_acc(&acc[2], ii + i + 4, jj + j);
save_acc(&acc[3], ii + i + 4, jj + j + 4);
save_acc(&acc[4], ii + i + 8, jj + j);
save_acc(&acc[5], ii + i + 8, jj + j + 4);
save_acc(&acc[6], ii + i + 12, jj + j);
save_acc(&acc[7], ii + i + 12, jj + j + 4);
} else {
add_save_acc(&acc[0], ii + i, jj + j);
add_save_acc(&acc[1], ii + i, jj + j + 4);
add_save_acc(&acc[2], ii + i + 4, jj + j);
add_save_acc(&acc[3], ii + i + 4, jj + j + 4);
add_save_acc(&acc[4], ii + i + 8, jj + j);
add_save_acc(&acc[5], ii + i + 8, jj + j + 4);
add_save_acc(&acc[6], ii + i + 12, jj + j);
add_save_acc(&acc[7], ii + i + 12, jj + j + 4);
}
}
}
}
void matmul_tiled(int64_t m , int64_t n, int64_t mc, int64_t nc, int64_t kc) {
int64_t ytiles = m / mc;
int64_t xtiles = n / nc;
int64_t tiles = xtiles * ytiles;
int64_t duty = (tiles + nth - 1) / nth;
int64_t start = duty * ith;
int64_t end = start + duty;
if (end > tiles) {
end = tiles;
}
for (int64_t job = start; job < end; ++job) {
int64_t ii = (job / xtiles) * mc;
int64_t jj = (job % xtiles) * nc;
for (int64_t kk = 0; kk < k; kk += kc) {
vec_t A_pack[kc * mc / 4];
vec_t B_pack[kc * nc / 4];
packTranspose(A + (ii * lda) + kk, lda, kc, mc, (float *)A_pack);
packTranspose(B + (jj * ldb) + kk, ldb, kc, nc, (float *)B_pack);
KERNEL(ii, jj, mc, nc, kc, A_pack, B_pack, kk);
}
}
SAVE_ACC(&acc_0, ii, jj);
SAVE_ACC(&acc_1, ii, jj+4);
SAVE_ACC(&acc_2, ii+4, jj);
SAVE_ACC(&acc_3, ii+4, jj+4);
}
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
@ -2406,35 +2509,35 @@ class tinyBLAS_PPC {
int n_rem = MIN(n - n0, 8);
int mc = 0, nc = 0;
if (m_rem >= 8 && n_rem >= 8) {
mc = 8;
nc = 8;
gemm<8, 8>(m0, m, n0, n);
mc = 8;
nc = 8;
gemm<8, 8>(m0, m, n0, n);
} else if (m_rem >= 4 && n_rem >= 8) {
mc = 4;
nc = 8;
gemm<4, 8>(m0, m, n0, n);
mc = 4;
nc = 8;
gemm<4, 8>(m0, m, n0, n);
} else if (m_rem >= 8 && n_rem >= 4) {
mc = 8;
nc = 4;
gemm<8, 4>(m0, m, n0, n);
mc = 8;
nc = 4;
gemm<8, 4>(m0, m, n0, n);
} else if (m_rem >= 4 && n_rem >= 4) {
mc = 4;
nc = 4;
gemm<4, 4>(m0, m, n0, n);
mc = 4;
nc = 4;
gemm<4, 4>(m0, m, n0, n);
} else {
mc = (m_rem >= 4) ? 4 : m_rem;
nc = (n_rem >= 4) ? 4 : n_rem;
if (mc == 0 || nc == 0)
return;
return;
gemm_small(m0, m, n0, n, mc, nc);
}
int64_t mp = m0 + ((m - m0) / mc) * mc;
int64_t np = n0 + ((n - n0) / nc) * nc;
mnpack(mp, m, n0, np);
mnpack(m0, m, np, n);
}
}
void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
int64_t ytiles = (m - m0) / RM;
int64_t xtiles = (n - n0) / RN;
int64_t tiles = xtiles * ytiles;
@ -2449,30 +2552,30 @@ class tinyBLAS_PPC {
vec_t vec_C[4];
acc_t acc_0;
__builtin_mma_xxsetaccz(&acc_0);
vec_t vec_A[4] {0}, vec_B[4] = {0};
for (int l=0; l<k; l+=4) {
vec_t vec_A[4] = {0}, vec_B[4] = {0};
for (int l = 0; l < k; l += 4) {
/* 'GEMV Forwarding' concept is used in first two conditional loops.
* when one of the matrix has a single row/column, the elements are
* broadcasted, instead of using packing routine to prepack the
* matrix elements.
*/
if (RM == 1) {
float* a = const_cast<float*>(A+(ii)*lda+l);
packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
float * a = const_cast<float *>(A + (ii) * lda + l);
packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
vec_A[0] = (vec_t)vec_xl(0,a);
vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
vec_A[1] = (vec_t)vec_splats(*((float *)&vec_A+1));
vec_A[2] = (vec_t)vec_splats(*((float *)&vec_A+2));
vec_A[3] = (vec_t)vec_splats(*((float *)&vec_A+3));
} else if (RN == 1) {
packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
float* b = const_cast<float*>(B+(jj)*ldb+l);
packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
float * b = const_cast<float *>(B + (jj) * ldb + l);
vec_B[0] = (vec_t)vec_xl(0,b);
vec_B[1] = (vec_t)vec_splats(*((float*)&vec_B+1));
vec_B[2] = (vec_t)vec_splats(*((float*)&vec_B+2));
vec_B[3] = (vec_t)vec_splats(*((float*)&vec_B+3));
vec_B[1] = (vec_t)vec_splats(*((float *)&vec_B+1));
vec_B[2] = (vec_t)vec_splats(*((float *)&vec_B+2));
vec_B[3] = (vec_t)vec_splats(*((float *)&vec_B+3));
} else {
packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
}
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
@ -2482,12 +2585,27 @@ class tinyBLAS_PPC {
__builtin_mma_disassemble_acc(vec_C, &acc_0);
for (int I = 0; I < RM; I++) {
for (int J = 0; J < RN; J++) {
*((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
*((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
}
}
}
}
template<int RM, int RN>
inline void kernel(int64_t ii, int64_t jj) {
if constexpr(RM == 4 && RN == 4) {
KERNEL_4x4(ii, jj);
} else if constexpr(RM == 4 && RN == 8) {
KERNEL_4x8(ii, jj);
} else if constexpr(RM == 8 && RN == 4) {
KERNEL_8x4(ii, jj);
} else if constexpr(RM == 8 && RN == 8) {
KERNEL_8x8(ii, jj);
} else {
static_assert(false, "RN/RM values not supported");
}
}
template <int RM, int RN>
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
int64_t ytiles = (m - m0) / RM;
@ -2496,27 +2614,18 @@ class tinyBLAS_PPC {
int64_t duty = (tiles + nth - 1) / nth;
int64_t start = duty * ith;
int64_t end = start + duty;
if (RM == 4 && RN == 4) {
kernel = &tinyBLAS_PPC::KERNEL_4x4;
} else if (RM == 4 && RN == 8) {
kernel = &tinyBLAS_PPC::KERNEL_4x8;
} else if (RM == 8 && RN == 4) {
kernel = &tinyBLAS_PPC::KERNEL_8x4;
} else if (RM == 8 && RN == 8) {
kernel = &tinyBLAS_PPC::KERNEL_8x8;
}
if (end > tiles)
end = tiles;
for (int64_t job = start; job < end; ++job) {
int64_t ii = m0 + job / xtiles * RM;
int64_t jj = n0 + job % xtiles * RN;
(this->*kernel)(ii, jj);
kernel<RM, RN>(ii, jj);
}
}
const float *const A;
const float *const B;
float *C;
const float * const A;
const float * const B;
float * C;
const int64_t k;
const int64_t lda;
const int64_t ldb;

View File

@ -9003,8 +9003,7 @@ static void ggml_compute_forward_ssm_scan_f32(
GGML_ASSERT(src4->nb[0] == sizeof(float));
GGML_ASSERT(src5->nb[0] == sizeof(float));
GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
// allows optimizing the modulo since n_group should be a power of 2
GGML_ASSERT((ng & -ng) == ng);
GGML_ASSERT(nh % ng == 0);
// heads per thread
const int dh = (nh + nth - 1)/nth;
@ -9035,6 +9034,7 @@ static void ggml_compute_forward_ssm_scan_f32(
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
const float dA = expf(dt_soft_plus * A[h]);
const int g = h / (nh / ng); // repeat_interleave
// dim
for (int i1 = 0; i1 < nr; ++i1) {
@ -9057,8 +9057,8 @@ static void ggml_compute_forward_ssm_scan_f32(
// TODO: maybe unroll more?
for (int j = 0; j < 1; j++) {
GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc);
GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc);
t0 = GGML_F32_VEC_MUL(t0, adA);
t1 = GGML_F32_VEC_MUL(t1, axdt);
@ -9072,6 +9072,9 @@ static void ggml_compute_forward_ssm_scan_f32(
}
sumf = GGML_F32xt_REDUCE_ONE(sum);
#elif defined(__riscv_v_intrinsic)
// todo: RVV implementation
const int np = 0;
#else
const int np = (nc & ~(GGML_F32_STEP - 1));
@ -9087,8 +9090,8 @@ static void ggml_compute_forward_ssm_scan_f32(
for (int i = 0; i < np; i += GGML_F32_STEP) {
for (int j = 0; j < GGML_F32_ARR; j++) {
ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + g*nc);
az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + g*nc);
ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
@ -9110,7 +9113,7 @@ static void ggml_compute_forward_ssm_scan_f32(
// d_state
for (int i0 = np; i0 < nc; ++i0) {
const int i = i0 + ii*nc;
const int ig = i0 + (h & (ng - 1))*nc;
const int ig = i0 + g*nc;
// state = prev_state * dA + dB * x
const float state = (s0[i] * dA) + (B[ig] * x_dt);
// y = rowwise_dotprod(state, C)
@ -9127,6 +9130,7 @@ static void ggml_compute_forward_ssm_scan_f32(
for (int h = ih0; h < ih1; ++h) {
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
const int g = h / (nh / ng); // repeat_interleave
// dim
for (int i1 = 0; i1 < nr; ++i1) {
@ -9141,8 +9145,8 @@ static void ggml_compute_forward_ssm_scan_f32(
// TODO: what happens when (d_state % svcntw()) != 0?
for (int64_t k = 0; k < nc; k += svcntw()) {
svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]);
svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]);
svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
@ -9162,7 +9166,7 @@ static void ggml_compute_forward_ssm_scan_f32(
// d_state
for (int i0 = 0; i0 < nc; ++i0) {
const int i = i0 + ii*nc;
const int ig = i0 + (h & (ng - 1))*nc;
const int ig = i0 + g*nc;
// state = prev_state * dA + dB * x
const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
// y = rowwise_dotprod(state, C)
@ -10023,8 +10027,8 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
int64_t h_stride_2d = head_size * head_size;
#if defined(GGML_SIMD)
#if defined(__ARM_FEATURE_SVE)
// scalar Route to scalar implementation //TODO: Write SVE code
#if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic)
// scalar Route to scalar implementation //TODO: Write SVE code and RVV code
for (int64_t t = 0; t < T; t++) {
int64_t t_offset = t * t_stride;
int64_t state_offset = head_size * C * (t / (T / n_seqs));

View File

@ -18,6 +18,10 @@
#include <immintrin.h>
#endif
#if defined(__riscv_v_intrinsic)
#include <riscv_vector.h>
#endif
#ifdef __cplusplus
extern "C" {
#endif
@ -94,24 +98,15 @@ extern "C" {
}
#elif defined(__riscv) && defined(__riscv_zfhmin)
static inline float riscv_compute_fp16_to_fp32(ggml_fp16_t h) {
float f;
__asm__(
"fmv.h.x %[f], %[h]\n\t"
"fcvt.s.h %[f], %[f]"
: [f] "=&f" (f)
: [h] "r" (h)
);
return f;
_Float16 hf;
memcpy(&hf, &h, sizeof(ggml_fp16_t));
return hf;
}
static inline ggml_fp16_t riscv_compute_fp32_to_fp16(float f) {
ggml_fp16_t res;
__asm__(
"fcvt.h.s %[f], %[f]\n\t"
"fmv.x.h %[h], %[f]"
: [h] "=&r" (res)
: [f] "f" (f)
);
_Float16 hf = (_Float16)f;
memcpy(&res, &hf, sizeof(ggml_fp16_t));
return res;
}
@ -1170,6 +1165,36 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) {
#define GGML_F16_VEC_MUL GGML_F32x4_MUL
#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
#elif defined(__riscv_v_intrinsic)
// compatible with vlen >= 128
#define GGML_SIMD
// F32
#define GGML_F32_STEP 16
#define GGML_F32_EPR 4
#define GGML_F32x4 vfloat32m1_t
#define GGML_F32x4_ZERO __riscv_vfmv_v_f_f32m1(0.0f, GGML_F32_EPR)
#define GGML_F32x4_SET1(x) __riscv_vfmv_v_f_f32m1(x, GGML_F32_EPR)
#define GGML_F32x4_LOAD(x) __riscv_vle32_v_f32m1(x, GGML_F32_EPR)
#define GGML_F32x4_STORE(b, v) __riscv_vse32_v_f32m1(b, v, GGML_F32_EPR)
#define GGML_F32x4_FMA(a, b, c) __riscv_vfmacc_vv_f32m1(a, b, c, GGML_F32_EPR)
#define GGML_F32x4_ADD(a, b) __riscv_vfadd_vv_f32m1(a, b, GGML_F32_EPR)
#define GGML_F32x4_MUL(a, b) __riscv_vfmul_vv_f32m1(a, b, GGML_F32_EPR)
#define GGML_F32_VEC GGML_F32x4
#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
#define GGML_F32_VEC_SET1 GGML_F32x4_SET1
#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
#define GGML_F32_VEC_STORE GGML_F32x4_STORE
#define GGML_F32_VEC_FMA GGML_F32x4_FMA
#define GGML_F32_VEC_ADD GGML_F32x4_ADD
#define GGML_F32_VEC_MUL GGML_F32x4_MUL
#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
#endif
// GGML_F32_ARR / GGML_F16_ARR

View File

@ -84,6 +84,16 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
}
// reduce sum1,sum2 to sum1
GGML_F32_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8);
#elif defined(__riscv_v_intrinsic)
vfloat32m1_t vsum = __riscv_vfmv_v_f_f32m1(0.0f, 1);
for (int i = 0, avl; i < n; i += avl) {
avl = __riscv_vsetvl_e32m8(n - i);
vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl);
vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
vfloat32m8_t prod = __riscv_vfmul_vv_f32m8(ax, ay, avl);
vsum = __riscv_vfredusum_vs_f32m8_f32m1(prod, vsum, avl);
}
sumf += __riscv_vfmv_f_s_f32m1_f32(vsum);
#else
const int np = (n & ~(GGML_F32_STEP - 1));
@ -197,7 +207,7 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G
ggml_float sumf = 0.0;
#if defined(GGML_SIMD)
#if defined(GGML_SIMD) && !defined(__riscv_v_intrinsic)
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
@ -325,6 +335,15 @@ ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float
vst1q_f32(y + i, val);
sum += (ggml_float)vaddvq_f32(val);
}
#elif defined(__riscv_v_intrinsic)
vfloat64m1_t vsum = __riscv_vfmv_v_f_f64m1(0, 1);
for (int avl; i < n; i += avl) {
avl = __riscv_vsetvl_e32m2(n - i);
vfloat32m2_t val = ggml_v_expf_m2(__riscv_vfsub_vf_f32m2(__riscv_vle32_v_f32m2(&x[i], avl), max, avl), avl);
__riscv_vse32_v_f32m2(&y[i], val, avl);
vsum = __riscv_vfwredusum_vs_f32m2_f64m1(val, vsum, avl);
}
return (ggml_float)__riscv_vfmv_f_s_f64m1_f64(vsum);
#endif
for (; i < n; ++i) {
float val = expf(x[i] - max);

View File

@ -119,6 +119,14 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
}
#if defined(GGML_SIMD)
#if defined(__riscv_v_intrinsic)
// todo: RVV impl
for (int i = 0; i < n; ++i) {
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
}
}
#else
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
@ -149,6 +157,7 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
}
}
#endif
#else
for (int i = 0; i < n; ++i) {
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
@ -243,6 +252,14 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
svst1_f32(pg, y + np2, ay1);
}
#elif defined(__riscv_v_intrinsic)
for (int i = 0, avl; i < n; i += avl) {
avl = __riscv_vsetvl_e32m8(n - i);
vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl);
vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
vfloat32m8_t ny = __riscv_vfmadd_vf_f32m8(ax, v, ay, avl);
__riscv_vse32_v_f32m8(&y[i], ny, avl);
}
#else
const int np = (n & ~(GGML_F32_STEP - 1));
@ -276,6 +293,13 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) {
#if defined(GGML_SIMD)
#if defined(__riscv_v_intrinsic)
// todo: RVV impl
// scalar
for (int i = 0; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
}
#else
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
@ -297,6 +321,7 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y,
for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
}
#endif
#else
// scalar
for (int i = 0; i < n; ++i) {
@ -324,6 +349,16 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int
y[i] += x[k][i]*v[k][0];
}
}
#elif defined(__riscv_v_intrinsic)
for (int i = 0, avl; i < n; i += avl) {
avl = __riscv_vsetvl_e32m8(n - i);
vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
for (int k = 0; k < GGML_VEC_MAD_UNROLL; k++) {
vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[k][i], avl);
ay = __riscv_vfmadd_vf_f32m8(ax, v[k][0], ay, avl);
}
__riscv_vse32_v_f32m8(&y[i], ay, avl);
}
#else
const int np = (n & ~(GGML_F32_STEP - 1));
@ -375,6 +410,14 @@ inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, co
for (int i = 0; i < n; ++i) {
y[i] = x[i]*s + b;
}
#elif defined(__riscv_v_intrinsic)
for (int i = 0, avl; i < n; i += avl) {
avl = __riscv_vsetvl_e32m8(n - i);
vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl);
vfloat32m8_t vb = __riscv_vfmv_v_f_f32m8(b, avl);
vfloat32m8_t ny = __riscv_vfmadd_vf_f32m8(ax, s, vb, avl);
__riscv_vse32_v_f32m8(&y[i], ny, avl);
}
#else
const int np = (n & ~(GGML_F32_STEP - 1));
@ -436,6 +479,13 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
ay1 = svmul_f32_m(pg, ay1, vx);
svst1_f32(pg, y + np, ay1);
}
#elif defined(__riscv_v_intrinsic)
for (int i = 0, avl; i < n; i += avl) {
avl = __riscv_vsetvl_e32m8(n - i);
vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
vfloat32m8_t ny = __riscv_vfmul_vf_f32m8(ay, v, avl);
__riscv_vse32_v_f32m8(&y[i], ny, avl);
}
#else
const int np = (n & ~(GGML_F32_STEP - 1));
@ -467,6 +517,13 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
#if defined(GGML_SIMD)
#if defined(__riscv_v_intrinsic)
// todo: RVV impl
// scalar
for (int i = 0; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
}
#else
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
@ -486,6 +543,7 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
}
#endif
#else
// scalar
for (int i = 0; i < n; ++i) {
@ -928,7 +986,51 @@ inline static __m128 ggml_v_silu(__m128 x) {
return _mm_div_ps(x, one_plus_exp_neg_x);
}
#endif // __ARM_NEON / __AVX2__ / __SSE2__
#elif defined(__riscv_v_intrinsic)
// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
inline static vfloat32m2_t ggml_v_expf_m2(vfloat32m2_t x, int vl) {
const vfloat32m2_t r = __riscv_vfmv_v_f_f32m2(0x1.8p23f, vl);
#ifdef __riscv_xtheadvector
// workaround for compiler bug (gcc 14.3.0: Error: unrecognized opcode `th.vmv1r.v v2,v4')
vfloat32m2_t z = __riscv_vfadd_vf_f32m2(r, 0.0f, vl);
z = __riscv_vfmacc_vf_f32m2(z, 0x1.715476p+0f, x, vl);
#else
const vfloat32m2_t z = __riscv_vfmacc_vf_f32m2(r, 0x1.715476p+0f, x, vl);
#endif
const vfloat32m2_t n = __riscv_vfsub_vv_f32m2(z, r, vl);
const vfloat32m2_t b = __riscv_vfnmsac_vf_f32m2(__riscv_vfnmsac_vf_f32m2(x, 0x1.62e4p-1f, n, vl),
0x1.7f7d1cp-20f, n, vl);
const vuint32m2_t e = __riscv_vsll_vx_u32m2(__riscv_vreinterpret_v_f32m2_u32m2(z), 23, vl);
const vfloat32m2_t k = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(e, 0x3f800000, vl)); // 1.0f
const vbool16_t c = __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 126.0f, vl);
const vfloat32m2_t u = __riscv_vfmul_vv_f32m2(b, b, vl);
const vfloat32m2_t j = __riscv_vfmacc_vv_f32m2(
__riscv_vfmul_vf_f32m2(b, 0x1.ffffecp-1f, vl),
__riscv_vfmacc_vv_f32m2(
__riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.fffdb6p-2f, vl), 0x1.555e66p-3f, b, vl),
__riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.573e2ep-5f, vl), 0x1.0e4020p-7f, b, vl),
u, vl), u, vl);
if (!__riscv_vcpop_m_b16(c, vl))
return __riscv_vfmacc_vv_f32m2(k, j, k, vl);
const vbool16_t dm = __riscv_vmfle_vf_f32m2_b16(n, 0.0f, vl);
const vuint32m2_t d = __riscv_vmerge_vxm_u32m2(__riscv_vmv_v_x_u32m2(0, vl), 0x82000000, dm, vl);
const vfloat32m2_t s1 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(d, 0x7f000000, vl));
const vfloat32m2_t s2 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vsub_vv_u32m2(e, d, vl));
const vfloat32m2_t r1 = __riscv_vmerge_vvm_f32m2(
__riscv_vfmacc_vv_f32m2(k, k, j, vl),
__riscv_vfmul_vv_f32m2(__riscv_vfmacc_vv_f32m2(s2, s2, j, vl), s1, vl),
c, vl);
return __riscv_vmerge_vvm_f32m2(
r1, __riscv_vfmul_vv_f32m2(s1, s1, vl),
__riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 192.0f, vl),
vl);
}
#endif // __ARM_NEON / __AVX2__ / __SSE2__ / __riscv_v_intrinsic
inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
for (int i = 0; i < n; ++i) {

View File

@ -94,7 +94,11 @@ if (CUDAToolkit_FOUND)
# As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas)
else ()
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static)
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "10.1")
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
else()
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static)
endif()
endif()
else()
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas)

View File

@ -1,5 +1,6 @@
#include "binbcast.cuh"
#include <cstdint>
#include <utility>
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
return b;
@ -22,13 +23,16 @@ static __device__ __forceinline__ float op_div(const float a, const float b) {
return a / b;
}
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
int ne0, int ne1, int ne2, int ne3,
int ne10, int ne11, int ne12, int ne13,
/*int s0, */ int s1, int s2, int s3,
/*int s00,*/ int s01, int s02, int s03,
/*int s10,*/ int s11, int s12, int s13) {
const int ne0, const int ne1, const int ne2, const int ne3,
const int ne10, const int ne11, const int ne12, const int ne13,
/*int s0, */ const int s1, const int s2, const int s3,
/*int s00,*/ const int s01, const int s02, const int s03,
/*int s10,*/ const int s11, const int s12, const int s13,
src1_ptrs... src1s) {
const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
@ -46,24 +50,31 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
const src0_t * src0_row = src0 + i_src0;
const src1_t * src1_row = src1 + i_src1;
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
dst_t * dst_row = dst + i_dst;
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) {
const int i10 = i0 % ne10;
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
float result = src0_row ? (float) src0_row[i0] : 0.0f;
if constexpr (sizeof...(src1_ptrs) > 0) {
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
} else {
result = bin_op(result, (float)src1[i_src1 + i10]);
}
dst_row[i0] = (dst_t) result;
}
}
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
int ne0, int ne1, int ne2, int ne3,
int ne10, int ne11, int ne12, int ne13,
/*int s0, */ int s1, int s2, int s3,
/*int s00,*/ int s01, int s02, int s03,
/*int s10,*/ int s11, int s12, int s13) {
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
const int ne0, const int ne1, const int ne2,const int ne3,
const int ne10, const int ne11, const int ne12, const int ne13,
/*int s0, */ const int s1, const int s2, const int s3,
/*int s00,*/ const int s01, const int s02, const int s03,
/*int s10,*/ const int s11, const int s12, const int s13,
src1_ptrs ... src1s) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
const int i3 = i/(ne2*ne1*ne0);
@ -83,12 +94,190 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
const src0_t * src0_row = src0 + i_src0;
const src1_t * src1_row = src1 + i_src1;
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
dst_t * dst_row = dst + i_dst;
const int i10 = i0 % ne10;
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
float result = src0_row ? (float) src0_row[i0] : 0.0f;
if constexpr (sizeof...(src1_ptrs) > 0) {
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
} else {
result = bin_op(result, (float)src1[i_src1 + i10]);
}
dst_row[i0] = (dst_t) result;
}
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, size_t... I>
static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
cudaStream_t stream, std::index_sequence<I...>) {
GGML_TENSOR_BINARY_OP_LOCALS
int nr0 = ne10 / ne0;
int nr1 = ne11 / ne1;
int nr2 = ne12 / ne2;
int nr3 = ne13 / ne3;
int nr[4] = { nr0, nr1, nr2, nr3 };
int64_t cne[] = { ne0, ne1, ne2, ne3 };
int64_t cne0[] = { ne00, ne01, ne02, ne03 };
int64_t cne1[] = { ne10, ne11, ne12, ne13 };
size_t cnb[] = { nb0, nb1, nb2, nb3 };
size_t cnb0[] = { nb00, nb01, nb02, nb03 };
size_t cnb1[] = { nb10, nb11, nb12, nb13 };
auto collapse = [](int64_t cne[]) {
cne[0] *= cne[1];
cne[1] = cne[2];
cne[2] = cne[3];
cne[3] = 1;
};
auto collapse_nb = [](size_t cnb[], const int64_t cne[]) {
cnb[1] *= cne[1];
cnb[2] *= cne[2];
cnb[3] *= cne[3];
};
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
for (int i = 0; i < 4; i++) {
if (nr[i] != 1) {
break;
}
if (i > 0) {
collapse_nb(cnb, cne);
collapse_nb(cnb0, cne0);
collapse_nb(cnb1, cne1);
collapse(cne);
collapse(cne0);
collapse(cne1);
}
}
}
{
int64_t ne0 = cne[0];
int64_t ne1 = cne[1];
int64_t ne2 = cne[2];
int64_t ne3 = cne[3];
//int64_t ne00 = cne0[0]; GGML_UNUSED(ne00);
//int64_t ne01 = cne0[1]; GGML_UNUSED(ne01);
//int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
//int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
int64_t ne10 = cne1[0];
int64_t ne11 = cne1[1];
int64_t ne12 = cne1[2];
int64_t ne13 = cne1[3];
size_t nb0 = cnb[0];
size_t nb1 = cnb[1];
size_t nb2 = cnb[2];
size_t nb3 = cnb[3];
size_t nb00 = cnb0[0];
size_t nb01 = cnb0[1];
size_t nb02 = cnb0[2];
size_t nb03 = cnb0[3];
size_t nb10 = cnb1[0];
size_t nb11 = cnb1[1];
size_t nb12 = cnb1[2];
size_t nb13 = cnb1[3];
size_t s0 = nb0 / sizeof(dst_t);
size_t s1 = nb1 / sizeof(dst_t);
size_t s2 = nb2 / sizeof(dst_t);
size_t s3 = nb3 / sizeof(dst_t);
size_t s10 = nb10 / sizeof(src1_t);
size_t s11 = nb11 / sizeof(src1_t);
size_t s12 = nb12 / sizeof(src1_t);
size_t s13 = nb13 / sizeof(src1_t);
size_t s00 = nb00 / sizeof(src0_t);
size_t s01 = nb01 / sizeof(src0_t);
size_t s02 = nb02 / sizeof(src0_t);
size_t s03 = nb03 / sizeof(src0_t);
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
GGML_ASSERT(s0 == 1);
GGML_ASSERT(s00 == 1);
GGML_ASSERT(s10 == 1);
const int block_size = 128;
int64_t hne0 = std::max(ne0 / 2LL, 1LL);
dim3 block_dims;
block_dims.x = std::min<unsigned int>(hne0, block_size);
block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U);
dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x,
(ne1 + block_dims.y - 1) / block_dims.y,
(ne2 * ne3 + block_dims.z - 1) / block_dims.z);
if (block_nums.z > 65535) {
int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
if constexpr (sizeof...(I) > 0) {
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
ne0, ne1, ne2, ne3,
ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12,s13,
(const src1_t *) dst->src[I + 1]->data...);
} else {
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
ne0, ne1, ne2, ne3,
ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12,s13);
}
} else {
if constexpr (sizeof...(I) > 0) {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
<<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
ne0, ne1, ne2, ne3,
ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12,s13,
(const src1_t *) dst->src[I + 1]->data...);
} else {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
<<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
ne0, ne1, ne2, ne3,
ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12,s13);
}
}
}
}
template <typename T>
@ -120,160 +309,14 @@ static __global__ void k_repeat_back(
dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
}
template<float (*bin_op)(const float, const float)>
template <float (*bin_op)(const float, const float), int n_fuse = 1>
struct bin_bcast_cuda {
template<typename src0_t, typename src1_t, typename dst_t>
void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
cudaStream_t stream) {
GGML_TENSOR_BINARY_OP_LOCALS
int nr0 = ne10/ne0;
int nr1 = ne11/ne1;
int nr2 = ne12/ne2;
int nr3 = ne13/ne3;
int nr[4] = { nr0, nr1, nr2, nr3 };
// collapse dimensions until first broadcast dimension
int64_t cne[] = {ne0, ne1, ne2, ne3};
int64_t cne0[] = {ne00, ne01, ne02, ne03};
int64_t cne1[] = {ne10, ne11, ne12, ne13};
size_t cnb[] = {nb0, nb1, nb2, nb3};
size_t cnb0[] = {nb00, nb01, nb02, nb03};
size_t cnb1[] = {nb10, nb11, nb12, nb13};
auto collapse = [](int64_t cne[]) {
cne[0] *= cne[1];
cne[1] = cne[2];
cne[2] = cne[3];
cne[3] = 1;
};
auto collapse_nb = [](size_t cnb[], const int64_t cne[]) {
cnb[1] *= cne[1];
cnb[2] *= cne[2];
cnb[3] *= cne[3];
};
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
for (int i = 0; i < 4; i++) {
if (nr[i] != 1) {
break;
}
if (i > 0) {
collapse_nb(cnb, cne);
collapse_nb(cnb0, cne0);
collapse_nb(cnb1, cne1);
collapse(cne);
collapse(cne0);
collapse(cne1);
}
}
}
{
int64_t ne0 = cne[0];
int64_t ne1 = cne[1];
int64_t ne2 = cne[2];
int64_t ne3 = cne[3];
//int64_t ne00 = cne0[0]; GGML_UNUSED(ne00);
//int64_t ne01 = cne0[1]; GGML_UNUSED(ne01);
//int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
//int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
int64_t ne10 = cne1[0];
int64_t ne11 = cne1[1];
int64_t ne12 = cne1[2];
int64_t ne13 = cne1[3];
size_t nb0 = cnb[0];
size_t nb1 = cnb[1];
size_t nb2 = cnb[2];
size_t nb3 = cnb[3];
size_t nb00 = cnb0[0];
size_t nb01 = cnb0[1];
size_t nb02 = cnb0[2];
size_t nb03 = cnb0[3];
size_t nb10 = cnb1[0];
size_t nb11 = cnb1[1];
size_t nb12 = cnb1[2];
size_t nb13 = cnb1[3];
size_t s0 = nb0 / sizeof(dst_t);
size_t s1 = nb1 / sizeof(dst_t);
size_t s2 = nb2 / sizeof(dst_t);
size_t s3 = nb3 / sizeof(dst_t);
size_t s10 = nb10 / sizeof(src1_t);
size_t s11 = nb11 / sizeof(src1_t);
size_t s12 = nb12 / sizeof(src1_t);
size_t s13 = nb13 / sizeof(src1_t);
size_t s00 = nb00 / sizeof(src0_t);
size_t s01 = nb01 / sizeof(src0_t);
size_t s02 = nb02 / sizeof(src0_t);
size_t s03 = nb03 / sizeof(src0_t);
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
GGML_ASSERT(s0 == 1);
GGML_ASSERT(s00 == 1);
GGML_ASSERT(s10 == 1);
const int block_size = 128;
int64_t hne0 = std::max(ne0/2LL, 1LL);
dim3 block_dims;
block_dims.x = std::min<unsigned int>(hne0, block_size);
block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
block_dims.z = std::min(std::min<unsigned int>(ne2*ne3, block_size / block_dims.x / block_dims.y), 64U);
dim3 block_nums(
(hne0 + block_dims.x - 1) / block_dims.x,
(ne1 + block_dims.y - 1) / block_dims.y,
(ne2*ne3 + block_dims.z - 1) / block_dims.z
);
if (block_nums.z > 65535) {
// this is the maximum number of blocks in z dimension, fallback to 1D grid kernel
int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
k_bin_bcast_unravel<bin_op><<<block_num, block_size, 0, stream>>>(
src0_dd, src1_dd, dst_dd,
ne0, ne1, ne2, ne3,
ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00, */ s01, s02, s03,
/* s10, */ s11, s12, s13);
} else {
k_bin_bcast<bin_op><<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd,
ne0, ne1, ne2, ne3,
ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00, */ s01, s02, s03,
/* s10, */ s11, s12, s13);
}
}
launch_bin_bcast_pack<bin_op, src0_t, src1_t, dst_t>(
src0, src1, dst, src0_dd, src1_dd, dst_dd, stream, std::make_index_sequence<n_fuse>{});
}
};
@ -312,7 +355,7 @@ static void ggml_cuda_op_bin_bcast(
}
void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat, 0>>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
}
void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@ -331,6 +374,68 @@ void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}
template <float (*op)(const float, const float), int n_fuse>
static void ggml_cuda_op_fused_binbcast_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
cudaStream_t stream = ctx.stream();
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
launch_bin_bcast_pack<op, float, float, float>(src0, src1, dst,
(const float *) src0->data, (const float *) src1->data, (float *) dst->data,
stream, std::make_index_sequence<n_fuse>{});
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
launch_bin_bcast_pack<op, half, half, half>(src0, src1, dst,
(const half *) src0->data, (const half *) src1->data, (half *) dst->data,
stream, std::make_index_sequence<n_fuse>{});
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
launch_bin_bcast_pack<op, half, float, half>(src0, src1, dst,
(const half *) src0->data, (const float *) src1->data, (half *) dst->data,
stream, std::make_index_sequence<n_fuse>{});
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
launch_bin_bcast_pack<op, half, float, float>(src0, src1, dst,
(const half *) src0->data, (const float *) src1->data, (float *) dst->data,
stream, std::make_index_sequence<n_fuse>{});
} else {
fprintf(stderr,
"%s: unsupported types for fusion: dst: %s, src0: %s, src1: %s\n",
__func__, ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
GGML_ABORT("fatal error");
}
}
void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) {
GGML_ASSERT(2 <= n_fuse && n_fuse <= 8);
switch (n_fuse) {
case 2:
ggml_cuda_op_fused_binbcast_impl<op_add, 2>(ctx, dst);
break;
case 3:
ggml_cuda_op_fused_binbcast_impl<op_add, 3>(ctx, dst);
break;
case 4:
ggml_cuda_op_fused_binbcast_impl<op_add, 4>(ctx, dst);
break;
case 5:
ggml_cuda_op_fused_binbcast_impl<op_add, 5>(ctx, dst);
break;
case 6:
ggml_cuda_op_fused_binbcast_impl<op_add, 6>(ctx, dst);
break;
case 7:
ggml_cuda_op_fused_binbcast_impl<op_add, 7>(ctx, dst);
break;
case 8:
ggml_cuda_op_fused_binbcast_impl<op_add, 8>(ctx, dst);
break;
default:
GGML_ASSERT(false && "Unsupported n_fuse value");
}
}
void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];

View File

@ -7,3 +7,5 @@ void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse);

View File

@ -107,9 +107,9 @@ constexpr bool ggml_cuda_has_arch(const int arch) {
return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__);
}
constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur) {
constexpr int ggml_cuda_highest_compiled_arch_impl(const int /*arch*/, const int cur) {
if (cur == 0) {
GGML_ABORT("ggml was not compiled with any CUDA arch <= %d", arch);
return -1;
}
return cur;
}
@ -420,16 +420,28 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
template<int width = WARP_SIZE>
static __device__ __forceinline__ int warp_reduce_all(int x) {
#ifdef GGML_USE_HIP
if (width == ggml_cuda_get_physical_warp_size()) {
return __all_sync(0xffffffff, x);
} else {
#pragma unroll
for (int offset = width/2; offset > 0; offset >>= 1) {
x = x && __shfl_xor_sync(0xffffffff, x, offset, width);
for (int offset = width/2; offset > 0; offset >>= 1) {
x = __shfl_xor_sync(0xffffffff, x, offset, width) && x;
}
return x;
}
}
template<int width = WARP_SIZE>
static __device__ __forceinline__ int warp_reduce_any(int x) {
if (width == ggml_cuda_get_physical_warp_size()) {
return __any_sync(0xffffffff, x);
} else {
#pragma unroll
for (int offset = width/2; offset > 0; offset >>= 1) {
x = __shfl_xor_sync(0xffffffff, x, offset, width) || x;
}
return x;
}
return x;
#else
static_assert(width == WARP_SIZE, "width != WARP_SIZE not implemented");
return __all_sync(0xffffffff, x);
#endif // GGML_USE_HIP
}
template<int width = WARP_SIZE>

View File

@ -0,0 +1,171 @@
#include "conv2d.cuh"
struct conv_params {
const int64_t IW, IH;
const int64_t OW, OH;
const int64_t KW, KH;
const int64_t ST_X, ST_Y;
const int64_t PD_X, PD_Y;
const int64_t DL_X, DL_Y;
const int64_t IC, OC;
const int64_t B;
const int64_t TOTAL;
};
struct kernel_bounds {
int64_t y_min, y_max;
int64_t x_min, x_max;
};
__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) {
return (a > b) ? a : b;
}
__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) {
return (a < b) ? a : b;
}
__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) {
kernel_bounds bounds;
bounds.y_min = max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
bounds.y_max = min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
bounds.x_min = max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
bounds.x_max = min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
return bounds;
}
__device__ __forceinline__ int calculate_input_coord(int64_t out_coord,
int64_t kern_coord,
int64_t stride,
int64_t dilation,
int64_t padding) {
return out_coord * stride + kern_coord * dilation - padding;
}
struct whcn_layout {
__device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x;
}
__device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) {
return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx;
}
__device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x;
}
__device__ static void unpack_indices(int64_t global_idx,
const conv_params & P,
int64_t & n,
int64_t & c,
int64_t & out_y,
int64_t & out_x) {
out_x = global_idx % P.OW;
out_y = (global_idx / P.OW) % P.OH;
c = (global_idx / (P.OW * P.OH)) % P.OC;
n = global_idx / (P.OW * P.OH * P.OC);
}
};
template <typename T, typename Layout>
static __global__ void conv2d_kernel(const float * __restrict__ input,
const T * __restrict__ kernel,
float * __restrict__ output,
const conv_params P) {
const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (global_idx >= P.TOTAL) {
return;
}
int64_t n, c_out, out_y, out_x;
Layout::unpack_indices(global_idx, P, n, c_out, out_y, out_x);
T acc = 0;
for (int64_t c_in = 0; c_in < P.IC; ++c_in) {
kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P);
for (int64_t ky = bounds.y_min; ky < bounds.y_max; ++ky) {
const int64_t in_y = calculate_input_coord(out_y, ky, P.ST_Y, P.DL_Y, P.PD_Y);
for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) {
const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X);
T input_val;
if (std::is_same<T, half>::value) {
input_val = __float2half(input[Layout::input_index(n, c_in, in_y, in_x, P)]);
} else {
input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)];
}
T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
acc += (input_val * kernel_val);
}
}
}
// [N, OC, OH, OW]
output[Layout::output_index(n, c_out, out_y, out_x, P)] = (float) acc;
}
template <typename T>
static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE;
conv2d_kernel<T, whcn_layout><<<blocks, CUDA_CONV2D_BLOCK_SIZE, 0, st>>>(X_D, K_D, Y_D, P);
}
static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
conv2d_cuda<half>(X_D, K_D, Y_D, P, st);
}
static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
conv2d_cuda<float>(X_D, K_D, Y_D, P, st);
}
void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * kernel = dst->src[0];
const ggml_tensor * input = dst->src[1];
float * K_D = (float *) kernel->data;
const float * X_D = (const float *) input->data;
float * Y_D = (float *) dst->data;
GGML_ASSERT(ggml_is_contiguous(kernel));
GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32);
// same number of input channels
GGML_ASSERT(input->ne[2] == kernel->ne[2]);
cudaStream_t st = ctx.stream();
const int32_t * p = (const int32_t *) dst->op_params;
const int ST_X = p[0]; // stride_x
const int ST_Y = p[1]; // stride_y
const int PD_X = p[2]; // padding_x
const int PD_Y = p[3]; // padding_y
const int DL_X = p[4]; // dilation_x
const int DL_Y = p[5]; // dilation_y
// No cwhn
GGML_ASSERT(p[6] == false);
const int IW = input->ne[0]; // input_w
const int IH = input->ne[1]; // input_h
const int OW = dst->ne[0]; // output_w
const int OH = dst->ne[1]; // output_h
const int KW = kernel->ne[0]; // kernel_w
const int KH = kernel->ne[1]; // kernel_h
const int IC = input->ne[2]; // input_channels
const int OC = kernel->ne[3]; // ouptut_chanles
const int B = input->ne[3]; // n_batches
const int64_t total = B * OC * OH * OW;
conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total };
if (kernel->type == GGML_TYPE_F16) {
conv2d_cuda_f16(X_D, (half *) K_D, Y_D, params, st);
} else {
conv2d_cuda_f32(X_D, K_D, Y_D, params, st);
}
}

View File

@ -0,0 +1,5 @@
#pragma once
#include "common.cuh"
#define CUDA_CONV2D_BLOCK_SIZE 256
void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -12,6 +12,7 @@
#include "ggml-cuda/clamp.cuh"
#include "ggml-cuda/concat.cuh"
#include "ggml-cuda/conv-transpose-1d.cuh"
#include "ggml-cuda/conv2d.cuh"
#include "ggml-cuda/conv2d-dw.cuh"
#include "ggml-cuda/conv2d-transpose.cuh"
#include "ggml-cuda/convert.cuh"
@ -204,6 +205,8 @@ static ggml_cuda_device_info ggml_cuda_init() {
GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__);
#endif // GGML_CUDA_FORCE_CUBLAS
GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
std::vector<std::pair<int, std::string>> turing_devices_without_mma;
for (int id = 0; id < info.device_count; ++id) {
int device_vmm = 0;
@ -261,7 +264,25 @@ static ggml_cuda_device_info ggml_cuda_init() {
info.devices[id].cc = 100*prop.major + 10*prop.minor;
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
#endif // defined(GGML_USE_HIP)
std::string device_name(prop.name);
if (device_name == "NVIDIA GeForce MX450") {
turing_devices_without_mma.push_back({ id, device_name });
} else if (device_name == "NVIDIA GeForce MX550") {
turing_devices_without_mma.push_back({ id, device_name });
} else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") {
turing_devices_without_mma.push_back({ id, device_name });
}
#endif // defined(GGML_USE_HIP)
}
if (ggml_cuda_highest_compiled_arch(GGML_CUDA_CC_TURING) >= GGML_CUDA_CC_TURING && !turing_devices_without_mma.empty()) {
GGML_LOG_INFO("The following devices will have suboptimal performance due to a lack of tensor cores:\n");
for (size_t device_pos = 0; device_pos < turing_devices_without_mma.size(); device_pos++) {
GGML_LOG_INFO(
" Device %d: %s\n", turing_devices_without_mma[device_pos].first, turing_devices_without_mma[device_pos].second.c_str());
}
GGML_LOG_INFO(
"Consider compiling with CMAKE_CUDA_ARCHITECTURES=61-virtual;80-virtual and DGGML_CUDA_FORCE_MMQ to force the use of the Pascal code for Turing.\n");
}
for (int id = 0; id < info.device_count; ++id) {
@ -2431,6 +2452,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_IM2COL:
ggml_cuda_op_im2col(ctx, dst);
break;
case GGML_OP_CONV_2D:
ggml_cuda_op_conv2d(ctx, dst);
break;
case GGML_OP_CONV_2D_DW:
ggml_cuda_op_conv2d_dw(ctx, dst);
break;
@ -2797,9 +2821,14 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
return false;
}
if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
const ggml_tensor *mul = cgraph->nodes[node_idx+1];
const ggml_tensor *add = nullptr;
if (ops.size() == 3 && ops.begin()[2] == GGML_OP_ADD) {
add = cgraph->nodes[node_idx+2];
}
GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
@ -2811,6 +2840,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
return false;
}
if (add && (add->src[0]->type != GGML_TYPE_F32 ||
add->src[1]->type != GGML_TYPE_F32 ||
add->type != GGML_TYPE_F32) ) {
return false;
}
//if rms norm is the B operand, then we don't handle broadcast
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
return false;
@ -2821,6 +2856,10 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
return false;
}
if (add && (!ggml_is_contiguous(add->src[0]) || !ggml_is_contiguous_rows(add->src[1]))) {
return false;
}
return true;
}
@ -2867,7 +2906,46 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
if (!disable_fusion) {
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
if (node->op == GGML_OP_ADD) {
int n_fuse = 0;
ggml_op ops[8];
std::fill(ops, ops + 8, GGML_OP_ADD);
for (; n_fuse <= 6; ++n_fuse){
if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {
break;
}
if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) {
break;
}
if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) {
break;
}
}
n_fuse++;
if (n_fuse > 1) {
for (int j = 0; j < n_fuse - 1; ++j) {
node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
}
cgraph->nodes[i + n_fuse - 1]->data = node->data;
ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse);
i += n_fuse - 1;
continue;
}
}
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
i += 2;
continue;
}
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) {
ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
i++;
continue;
@ -3086,7 +3164,7 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
return false;
}
#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) || defined(GGML_USE_HIP)
cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
if (err != cudaSuccess) {
// clear the error
@ -3481,6 +3559,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
}
case GGML_OP_IM2COL:
case GGML_OP_CONV_2D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_CONV_TRANSPOSE_2D:
case GGML_OP_POOL_2D:

View File

@ -3,6 +3,140 @@
#include <vector>
// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
struct mmq_ids_helper_store {
uint32_t data;
__device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
data = (it & 0x003FFFFF) | (iex_used << 22);
}
__device__ uint32_t it() const {
return data & 0x003FFFFF;
}
__device__ uint32_t iex_used() const {
return data >> 22;
}
};
static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store");
// Helper function for mul_mat_id, converts ids to a more convenient format.
// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
// ids_dst describes the same mapping but for the dst tensor.
// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
template <int n_expert_used_template>
__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mmq_ids_helper(
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
const int expert = blockIdx.x;
extern __shared__ char data_mmq_ids_helper[];
mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper;
int nex_prev = 0; // Number of columns for experts with a lower index.
int it_compact = 0; // Running index for the compact slice of this expert.
if constexpr (n_expert_used_template == 0) {
// Generic implementation:
for (int it = 0; it < n_tokens; ++it) {
int iex_used = -1; // The index at which the expert is used, if any.
for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
const int expert_used = ids[it*si1 + iex];
nex_prev += expert_used < expert;
if (expert_used == expert) {
iex_used = iex;
}
}
if (iex_used != -1) {
store[it_compact] = mmq_ids_helper_store(it, iex_used);
}
if (warp_reduce_any<warp_size>(iex_used != -1)) {
it_compact++;
}
}
} else {
// Implementation optimized for specific numbers of experts used:
static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
const int it = it0 + threadIdx.x / neu_padded;
const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
ids[it*si1 + iex] : INT_MAX;
const int iex_used = expert_used == expert ? iex : -1;
nex_prev += expert_used < expert;
// Whether the threads at this token position have used the expert:
const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
// Do a scan over threads at lower token positions in warp to get the correct index for writing data:
int it_compact_add_lower = 0;
#pragma unroll
for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
if (threadIdx.x >= offset) {
it_compact_add_lower += tmp;
}
}
if (iex_used != -1) {
store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used);
}
// The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
}
}
nex_prev = warp_reduce_sum<warp_size>(nex_prev);
for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
const mmq_ids_helper_store store_it = store[itc];
const int it = store_it.it();
const int iex_used = store_it.iex_used();
ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
}
if (threadIdx.x != 0) {
return;
}
expert_bounds[expert] = nex_prev;
if (expert < gridDim.x - 1) {
return;
}
expert_bounds[gridDim.x] = nex_prev + it_compact;
}
template <int n_expert_used_template>
static void launch_mmq_ids_helper(
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mmq_ids_helper_store");
GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store");
const int id = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[id].warp_size;
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper<n_expert_used_template>, smpbo);
const dim3 num_blocks(n_experts, 1, 1);
const dim3 block_size(warp_size, 1, 1);
const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store);
GGML_ASSERT(nbytes_shared <= smpbo);
mmq_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
(ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
}
static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
switch (args.type_x) {
case GGML_TYPE_Q4_0:
@ -137,7 +271,7 @@ void ggml_cuda_mul_mat_q(
ne00, ne01, ne1, s01, ne11, s1,
ne02, ne12, s02, s12, s2,
ne03, ne13, s03, s13, s3,
use_stream_k};
use_stream_k, ne1};
ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
return;
}
@ -148,54 +282,50 @@ void ggml_cuda_mul_mat_q(
const int64_t n_expert_used = ids->ne[0];
const int64_t ne_get_rows = ne12 * n_expert_used;
GGML_ASSERT(ne1 == n_expert_used);
std::vector<char> ids_host(ggml_nbytes(ids));
std::vector<int32_t> ids_src1_host;
ids_src1_host.reserve(ne_get_rows);
std::vector<int32_t> ids_dst_host;
ids_dst_host.reserve(ne_get_rows);
std::vector<int32_t> tokens_per_expert_host(ne02);
std::vector<int32_t> expert_bounds_host(ne02 + 1);
ggml_cuda_pool_alloc<int32_t> ids_buf_dev(ctx.pool());
ggml_cuda_pool_alloc<int32_t> ids_src1(ctx.pool(), ne_get_rows);
ggml_cuda_pool_alloc<int32_t> ids_dst(ctx.pool(), ne_get_rows);
ggml_cuda_pool_alloc<int32_t> expert_bounds(ctx.pool(), ne02 + 1);
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
{
GGML_ASSERT(ids->nb[0] == ggml_element_size(ids));
const int si1 = ids->nb[1] / ggml_element_size(ids);
const int sis1 = nb12 / nb11;
for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices
for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens
for (int64_t iex = 0; iex < n_expert_used; ++iex) {
const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]);
assert(expert_to_use >= 0 && expert_to_use < ne02);
if (expert_to_use == i02) {
ids_src1_host.push_back(i12*(nb12/nb11) + iex % ne11);
ids_dst_host.push_back(i12*ne1 + iex);
tokens_per_expert_host[i02]++;
break;
}
}
switch (n_expert_used) {
case 2:
launch_mmq_ids_helper< 2> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 4:
launch_mmq_ids_helper< 4> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 6:
launch_mmq_ids_helper< 6> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 8:
launch_mmq_ids_helper< 8> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 16:
launch_mmq_ids_helper<16> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 32:
launch_mmq_ids_helper<32> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
default:
launch_mmq_ids_helper< 0> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
}
CUDA_CHECK(cudaGetLastError());
}
int32_t cumsum = 0;
for (int64_t i = 0; i < ne02; ++i) {
expert_bounds_host[i] = cumsum;
cumsum += tokens_per_expert_host[i];
}
expert_bounds_host[ne02] = cumsum;
std::vector<int32_t> ids_buf_host;
ids_buf_host.reserve(ids_src1_host.size() + ids_dst_host.size() + expert_bounds_host.size());
ids_buf_host.insert(ids_buf_host.end(), ids_src1_host.begin(), ids_src1_host.end());
ids_buf_host.insert(ids_buf_host.end(), ids_dst_host.begin(), ids_dst_host.end());
ids_buf_host.insert(ids_buf_host.end(), expert_bounds_host.begin(), expert_bounds_host.end());
ids_buf_dev.alloc(ids_buf_host.size() + get_mmq_x_max_host(cc)); // Expert bounds are padded on device.
CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_buf_host.data(), ids_buf_host.size()*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
const int32_t * ids_src1_dev = ids_buf_dev.ptr;
const int32_t * ids_dst_dev = ids_src1_dev + ids_src1_host.size();
const int32_t * expert_bounds_dev = ids_dst_dev + ids_dst_host.size();
const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +
get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);
@ -208,7 +338,7 @@ void ggml_cuda_mul_mat_q(
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[2] / ts_src1;
quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type,
quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type,
ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
CUDA_CHECK(cudaGetLastError());
}
@ -218,11 +348,11 @@ void ggml_cuda_mul_mat_q(
// Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
const mmq_args args = {
src0_d, src0->type, (const int *) src1_q8_1.ptr, ids_dst_dev, expert_bounds_dev, dst_d,
src0_d, src0->type, (const int *) src1_q8_1.get(), ids_dst.get(), expert_bounds.get(), dst_d,
ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
ne02, ne02, s02, s12, s2,
ne03, ne13, s03, s13, s3,
use_stream_k};
use_stream_k, ne12};
ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
}
@ -262,7 +392,7 @@ void ggml_cuda_op_mul_mat_q(
ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,
1, 1, 0, 0, 0,
1, 1, 0, 0, 0,
use_stream_k};
use_stream_k, src1_ncols};
ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);

View File

@ -3138,7 +3138,8 @@ static __global__ void mul_mat_q(
const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
const int ncols_max) {
// Skip unused template specializations for faster compilation:
if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
@ -3152,7 +3153,7 @@ static __global__ void mul_mat_q(
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int mmq_y = get_mmq_y_device();
const int ntx = (ncols_dst + mmq_x - 1) / mmq_x; // Number of tiles x
const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x
const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
// Initialize the ids for writing back data with just the index.
@ -3376,7 +3377,8 @@ template <ggml_type type, int mmq_x, bool need_check>
static __global__ void mul_mat_q_stream_k_fixup(
const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst) {
const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst,
const int ncols_max) {
constexpr int mmq_y = get_mmq_y_device();
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
@ -3387,7 +3389,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
const int ntx = (ncols_dst + mmq_x - 1) / mmq_x;
const int ntx = (ncols_max + mmq_x - 1) / mmq_x;
const int nty = (nrows_x + mmq_y - 1) / mmq_y;
const int bidx0 = blockIdx.x;
@ -3528,7 +3530,7 @@ struct mmq_args {
int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst;
int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
bool use_stream_k;
bool use_stream_k; int64_t ncols_max;
};
template<ggml_type type>
@ -3558,7 +3560,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, true>), nbytes_shared);
const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x;
const int ntx = (args.ncols_max + mmq_x - 1) / mmq_x;
const int ntzw = args.nchannels_y * args.nsamples_y;
const dim3 block_nums_xy_tiling(nty, ntx, ntzw);
@ -3574,14 +3576,16 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
args.ncols_max);
} else {
constexpr bool need_check = true;
mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
args.ncols_max);
}
return;
}
@ -3601,7 +3605,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
args.ncols_max);
if (!fixup_needed) {
return;
@ -3609,14 +3614,16 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
args.ncols_max);
} else {
constexpr bool need_check = true;
mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
args.ncols_max);
if (!fixup_needed) {
return;
@ -3624,7 +3631,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
args.ncols_max);
}
}
@ -3649,7 +3657,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
continue;
}
const int ntiles_x = (args.ncols_y + mmq_x - 1) / mmq_x;
const int ntiles_x = (args.ncols_max + mmq_x - 1) / mmq_x;
if (ntiles_x < ntiles_x_best) {
mmq_x_best = mmq_x;

View File

@ -104,12 +104,30 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
}
}
template <int block_size, bool do_multiply = false>
static __global__ void rms_norm_f32(
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
const int64_t stride_sample, const float eps, const float * mul = nullptr, const int64_t mul_stride_row = 0,
const int64_t mul_stride_channel = 0, const int64_t mul_stride_sample = 0, const int mul_ncols = 0,
const int mul_nrows = 0, const int mul_nchannels = 0, const int mul_nsamples = 0) {
template <int block_size, bool do_multiply = false, bool do_add = false>
static __global__ void rms_norm_f32(const float * x, float * dst,
const int ncols,
const int64_t stride_row,
const int64_t stride_channel,
const int64_t stride_sample,
const float eps,
const float * mul = nullptr,
const int64_t mul_stride_row = 0,
const int64_t mul_stride_channel = 0,
const int64_t mul_stride_sample = 0,
const int mul_ncols = 0,
const int mul_nrows = 0,
const int mul_nchannels = 0,
const int mul_nsamples = 0,
const float * add = nullptr,
const int64_t add_stride_row = 0,
const int64_t add_stride_channel = 0,
const int64_t add_stride_sample = 0,
const int add_ncols = 0,
const int add_nrows = 0,
const int add_nchannels = 0,
const int add_nsamples = 0) {
const int nrows = gridDim.x;
const int nchannels = gridDim.y;
@ -118,6 +136,8 @@ static __global__ void rms_norm_f32(
const int sample = blockIdx.z;
const int tid = threadIdx.x;
static_assert(!do_add || do_multiply, "fusing add is not supported without multiplying");
x += sample*stride_sample + channel*stride_channel + row*stride_row;
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
@ -128,6 +148,13 @@ static __global__ void rms_norm_f32(
mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row;
}
if constexpr (do_add) {
const int add_row = row % add_nrows;
const int add_channel = channel % add_nchannels;
const int add_sample = sample % add_nsamples;
add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
}
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
@ -154,7 +181,11 @@ static __global__ void rms_norm_f32(
const float scale = rsqrtf(mean + eps);
for (int col = tid; col < ncols; col += block_size) {
if constexpr (do_multiply) {
if constexpr (do_multiply && do_add) {
const int mul_col = col % mul_ncols;
const int add_col = col % add_ncols;
dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
} else if constexpr (do_multiply) {
const int mul_col = col % mul_ncols;
dst[col] = scale * x[col] * mul[mul_col];
} else {
@ -331,23 +362,70 @@ static void rms_norm_f32_cuda(
}
}
static void rms_norm_mul_f32_cuda(
const float * x, const float * mul, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample,
const int mul_ncols, const int mul_nrows, const int mul_nchannels, const int mul_nsamples,
const float eps, cudaStream_t stream) {
static void rms_norm_mul_f32_cuda(const float * x,
const float * mul,
const float * add,
float * dst,
const int ncols,
const int nrows,
const int nchannels,
const int nsamples,
const int64_t stride_row,
const int64_t stride_channel,
const int64_t stride_sample,
const int64_t mul_stride_row,
const int64_t mul_stride_channel,
const int64_t mul_stride_sample,
const int mul_ncols,
const int mul_nrows,
const int mul_nchannels,
const int mul_nsamples,
const int64_t add_stride_row,
const int64_t add_stride_channel,
const int64_t add_stride_sample,
const int add_ncols,
const int add_nrows,
const int add_nchannels,
const int add_nsamples,
const float eps,
cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (mul == nullptr) {
rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
return;
}
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
if (add == nullptr) {
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
ncols, stride_row, stride_channel, stride_sample, eps,
mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
ncols, stride_row, stride_channel, stride_sample, eps,
mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
}
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<WARP_SIZE, true, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
ncols, stride_row, stride_channel, stride_sample, eps,
mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
add, add_stride_row, add_stride_channel, add_stride_sample,
add_ncols, add_nrows, add_nchannels, add_nsamples);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
ncols, stride_row, stride_channel, stride_sample, eps,
mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
add, add_stride_row, add_stride_channel, add_stride_sample,
add_ncols, add_nrows, add_nchannels, add_nsamples);
}
}
}
@ -491,7 +569,102 @@ void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor *
const int mul_nchannels = mul_src->ne[2];
const int mul_nsamples = mul_src->ne[3];
rms_norm_mul_f32_cuda(src0_d, mul_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, mul_s01, mul_s02, mul_s03, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, eps, stream);
rms_norm_mul_f32_cuda(src0_d, mul_d, nullptr, dst_d,
ne00, ne01, ne02, ne03,
/*s00*/ s01, s02, s03,
/*mul_s00*/ mul_s01, mul_s02, mul_s03,
mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
/*add_s00*/ 0, 0, 0,
0, 0, 0, 0,
eps, stream);
}
void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
ggml_tensor * dst,
ggml_tensor * mul_tensor,
ggml_tensor * add_tensor) {
const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
float eps = 0.0f;
memcpy(&eps, dst->op_params, sizeof(float));
const float * src0_d = (const float *) rms_norm_src->data;
const float * mul_d = nullptr;
const ggml_tensor * mul_src = nullptr;
if (mul_tensor->src[0] == dst) {
mul_d = (float *) mul_tensor->src[1]->data;
mul_src = mul_tensor->src[1];
} else if (mul_tensor->src[1] == dst) {
mul_d = (float *) mul_tensor->src[0]->data;
mul_src = mul_tensor->src[0];
} else {
GGML_ASSERT(false);
}
const float * add_d = nullptr;
const ggml_tensor * add_src = nullptr;
if (add_tensor->src[0] == mul_tensor) {
add_d = (float *) add_tensor->src[1]->data;
add_src = add_tensor->src[1];
} else if (add_tensor->src[1] == mul_tensor) {
add_d = (float *) add_tensor->src[0]->data;
add_src = add_tensor->src[0];
} else {
GGML_ASSERT(false);
}
float * dst_d = (float *) add_tensor->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
GGML_ASSERT(add_tensor->type == GGML_TYPE_F32);
GGML_ASSERT(eps >= 0.0f);
const int64_t ne00 = rms_norm_src->ne[0];
const int64_t ne01 = rms_norm_src->ne[1];
const int64_t ne02 = rms_norm_src->ne[2];
const int64_t ne03 = rms_norm_src->ne[3];
const size_t ts0 = ggml_type_size(rms_norm_src->type);
GGML_ASSERT(rms_norm_src->nb[0] == ts0);
const int64_t s01 = rms_norm_src->nb[1] / ts0;
const int64_t s02 = rms_norm_src->nb[2] / ts0;
const int64_t s03 = rms_norm_src->nb[3] / ts0;
const size_t ts_mul = ggml_type_size(mul_src->type);
GGML_ASSERT(mul_src->nb[0] == ts_mul);
const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
const int64_t mul_s03 = mul_src->nb[3] / ts_mul;
const int mul_ncols = mul_src->ne[0];
const int mul_nrows = mul_src->ne[1];
const int mul_nchannels = mul_src->ne[2];
const int mul_nsamples = mul_src->ne[3];
const size_t ts_add = ggml_type_size(add_src->type);
GGML_ASSERT(add_src->nb[0] == ts_add);
const int64_t add_s01 = add_src->nb[1] / ts_add;
const int64_t add_s02 = add_src->nb[2] / ts_add;
const int64_t add_s03 = add_src->nb[3] / ts_add;
const int add_ncols = add_src->ne[0];
const int add_nrows = add_src->ne[1];
const int add_nchannels = add_src->ne[2];
const int add_nsamples = add_src->ne[3];
rms_norm_mul_f32_cuda(src0_d, mul_d,add_d,dst_d,
ne00,ne01, ne02, ne03,
/*s00*/ s01, s02, s03,
/*mul_s00*/ mul_s01, mul_s02, mul_s03,
mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
/*add_s00*/ add_s01, add_s02, add_s03,
add_ncols, add_nrows, add_nchannels, add_nsamples,
eps, stream);
}
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

View File

@ -8,6 +8,11 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor);
void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
ggml_tensor * dst,
ggml_tensor * mul_tensor,
ggml_tensor * add_tensor);
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -129,7 +129,7 @@ __global__ void __launch_bounds__(d_state, 1)
const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
const int seq_idx = blockIdx.y;
const int group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float);
const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);
const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));

View File

@ -28,7 +28,58 @@ static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32
return ((const int *) x)[i32]; // assume at least 4 byte alignment
}
// q4 contains 8 indices with 4 bit each.
// This function selects those bytes from table that are at those indices and returns them as int2.
// The first int contains the bytes with even indices in q4, the second int contains the bytes with odd indices in q4.
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {
#if defined(GGML_USE_HIP)
// Load the 16-byte table into four 32-bit unsigned integers.
const uint32_t *values = (const uint32_t *)table;
const uint32_t q_even = q4;
const uint32_t q_odd = (q4 >> 4);
// Perform lookups in the lower half of the table (indices 0-7).
uint32_t v_even_low = __builtin_amdgcn_perm(values[1], values[0], q_even & 0x07070707);
uint32_t v_odd_low = __builtin_amdgcn_perm(values[1], values[0], q_odd & 0x07070707);
// Perform lookups in the upper half of the table (indices 8-15).
uint32_t v_even_high = __builtin_amdgcn_perm(values[3], values[2], q_even & 0x07070707);
uint32_t v_odd_high = __builtin_amdgcn_perm(values[3], values[2], q_odd & 0x07070707);
// Select between the low and high results based on the MSB of each index nibble.
uint32_t mask_even = 0x03020100 | ((q_even & 0x08080808) >> 1);
uint32_t res_x = __builtin_amdgcn_perm(v_even_high, v_even_low, mask_even);
uint32_t mask_odd = 0x03020100 | ((q_odd & 0x08080808) >> 1);
uint32_t res_y = __builtin_amdgcn_perm(v_odd_high, v_odd_low, mask_odd);
return make_int2(res_x, res_y);
#elif !defined(GGML_USE_MUSA)
// CUDA does not have an instruction for selecting bytes with 4 bit indices.
// However, __byte_perm is an instruction that selects bytes with 3 bit indices that can be used instead.
const uint32_t * table32 = (const uint32_t *) table;
// __byte_perm selects bytes based on the lower 16 bits in its third argument.
// Therefore, do 2 iterations over the 32 bits in q4 with 0 and 16 shift.
// To handle the fourth bit, first call _byte_perm both for the low and the high 64 bit of table, using the low 3 bits.
// Then, call __byte_perm again to select from the low and high bytes based on the fourth bit.
uint32_t tmp[2];
const uint32_t low_high_selection_indices = (0x32103210 | ((q4 & 0x88888888) >> 1));
#pragma unroll
for (uint32_t i = 0; i < 2; ++i) {
const uint32_t shift = 16 * i;
const uint32_t low = __byte_perm(table32[0], table32[1], q4 >> shift);
const uint32_t high = __byte_perm(table32[2], table32[3], q4 >> shift);
tmp[i] = __byte_perm(low, high, low_high_selection_indices >> shift);
}
// tmp contains the bytes from tyble in the same order as the 4 bit indices in q4.
// However, for the result we need ints with all even/odd 4 bit indices in q4.
// Therefore, 2 more calls to __byte_perm to put the bytes in the correct order.
return make_int2(__byte_perm(tmp[0], tmp[1], 0x6420), __byte_perm(tmp[0], tmp[1], 0x7531));
#else
// Generic implementation.
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
const int8_t * q0_8 = (const int8_t *) &q0_32;
const char4 val0_8 = make_char4(
@ -40,6 +91,7 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]);
return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
#endif
}
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called

View File

@ -22,7 +22,10 @@
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
#define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width)
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define __all_sync(mask, var) __all(var)
#define __any_sync(mask, var) __any(var)
#define cublasCreate hipblasCreate
#define cublasDestroy hipblasDestroy
#define cublasGemmEx hipblasGemmEx

View File

@ -249,6 +249,7 @@ typedef struct {
uint64_t nb33;
int32_t ne1;
int32_t ne2;
int32_t ne3;
float scale;
float max_bias;
float m0;
@ -257,6 +258,11 @@ typedef struct {
float logit_softcap;
} ggml_metal_kargs_flash_attn_ext;
typedef struct {
int32_t nrows;
int32_t ne20;
} ggml_metal_kargs_flash_attn_ext_reduce;
typedef struct {
int32_t ne00;
int32_t ne02;
@ -320,40 +326,31 @@ typedef struct {
} ggml_metal_kargs_mul_mv_ext;
typedef struct {
int32_t ne02;
int32_t ne10;
int32_t ne11; // n_expert_used (bcast)
uint64_t nb11;
uint64_t nb12;
int32_t neh11; // n_tokens
uint64_t nbh11;
int32_t ne21; // n_tokens
int32_t ne20; // n_expert_used
uint64_t nb21;
} ggml_metal_kargs_mul_mm_id_map0;
typedef struct {
int32_t ne20; // n_expert_used
int32_t neh0;
int32_t neh1;
uint64_t nbh1;
uint64_t nbh2;
int32_t ne0;
uint64_t nb1;
uint64_t nb2;
} ggml_metal_kargs_mul_mm_id_map1;
typedef struct {
int32_t ne00;
int32_t ne02;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t neh12;
uint64_t nbh10;
uint64_t nbh11;
uint64_t nbh12;
uint64_t nbh13;
int32_t neh0;
int32_t neh1;
int32_t ne11;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ne20;
int32_t ne21;
int32_t ne0;
int32_t ne1;
int16_t r2;
int16_t r3;
} ggml_metal_kargs_mul_mm_id;

View File

@ -93,35 +93,37 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
if (ctx->mtl_device == nil) {
ctx->mtl_device = MTLCreateSystemDefaultDevice();
ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
if (ctx->mtl_device) {
ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
#endif
ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
#if defined(GGML_METAL_USE_BF16)
ctx->use_bfloat = ctx->has_bfloat;
ctx->use_bfloat = ctx->has_bfloat;
#else
ctx->use_bfloat = false;
ctx->use_bfloat = false;
#endif
ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
{
const char * val = getenv("GGML_METAL_FUSION_DEBUG");
ctx->debug_fusion = val ? atoi(val) : 0;
{
const char * val = getenv("GGML_METAL_FUSION_DEBUG");
ctx->debug_fusion = val ? atoi(val) : 0;
}
memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
ctx->max_size = ctx->mtl_device.maxBufferLength;
strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
}
memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
ctx->max_size = ctx->mtl_device.maxBufferLength;
strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
}
ctx->mtl_device_ref_count++;
@ -289,6 +291,10 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
@ -396,8 +402,12 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
@ -443,6 +453,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
@ -452,6 +463,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
@ -461,6 +473,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
@ -470,6 +483,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
@ -479,6 +493,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
@ -488,6 +503,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
@ -497,6 +513,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
@ -506,6 +523,13 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64,
@ -555,6 +579,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE,
GGML_METAL_KERNEL_TYPE_SET_I32,
GGML_METAL_KERNEL_TYPE_SET_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
@ -1304,6 +1329,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2, mul_mv_ext_f32_f32_r1_2, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3, mul_mv_ext_f32_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4, mul_mv_ext_f32_f32_r1_4, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5, mul_mv_ext_f32_f32_r1_5, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
@ -1412,8 +1441,12 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1, mul_mm_id_map0_f16_ne20_1, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2, mul_mm_id_map0_f16_ne20_2, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4, mul_mm_id_map0_f16_ne20_4, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6, mul_mm_id_map0_f16_ne20_6, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8, mul_mm_id_map0_f16_ne20_8, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16, mul_mm_id_map0_f16_ne20_16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat);
@ -1459,6 +1492,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40, flash_attn_ext_f16_h40, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm);
@ -1468,6 +1502,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, flash_attn_ext_f16_hk576_hv512, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40, flash_attn_ext_bf16_h40, has_simdgroup_mm && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
@ -1477,6 +1512,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40, flash_attn_ext_q4_0_h40, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
@ -1486,6 +1522,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40, flash_attn_ext_q4_1_h40, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
@ -1495,6 +1532,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40, flash_attn_ext_q5_0_h40, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
@ -1504,6 +1542,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40, flash_attn_ext_q5_1_h40, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
@ -1513,6 +1552,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40, flash_attn_ext_q8_0_h40, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
@ -1522,6 +1562,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40, flash_attn_ext_vec_f16_h40, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40, flash_attn_ext_vec_bf16_h40, has_simdgroup_reduction && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40, flash_attn_ext_vec_q4_0_h40, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40, flash_attn_ext_vec_q4_1_h40, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40, flash_attn_ext_vec_q5_0_h40, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40, flash_attn_ext_vec_q5_1_h40, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40, flash_attn_ext_vec_q8_0_h40, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction);
@ -1571,6 +1618,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE, flash_attn_ext_reduce, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
@ -1846,7 +1894,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_ROPE:
return true;
case GGML_OP_IM2COL:
return op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
return ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
case GGML_OP_POOL_1D:
return false;
case GGML_OP_UPSCALE:
@ -3347,15 +3395,16 @@ static int ggml_metal_encode_node(
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
// to the matrix-vector kernel
const int ne11_mm_min = 4;
const int ne11_mm_min = 8;
// first try to use small-batch mat-mv kernels
// these should be efficient for BS [2, ~8]
if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) &&
if (src1t == GGML_TYPE_F32 && (ne00%128 == 0) &&
(
(
(
src0t == GGML_TYPE_F16 || // TODO: helper function
src0t == GGML_TYPE_F32 || // TODO: helper function
src0t == GGML_TYPE_F16 ||
src0t == GGML_TYPE_Q4_0 ||
src0t == GGML_TYPE_Q4_1 ||
src0t == GGML_TYPE_Q5_0 ||
@ -3383,7 +3432,17 @@ static int ggml_metal_encode_node(
// values and there can be some tail effects when nsg is high. need to confirm this
//
const int nsg = 2; // num simdgroups per threadgroup
const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup
// num threads along row per simdgroup
int nxpsg = 0;
if (ne00 % 256 == 0 && ne11 < 3) {
nxpsg = 16;
} else if (ne00 % 128 == 0) {
nxpsg = 8;
} else {
nxpsg = 4;
}
const int nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
const int r0ptg = nypsg*nsg; // num src0 rows per threadgroup
int r1ptg = 4; // num src1 rows per threadgroup
@ -3406,6 +3465,14 @@ static int ggml_metal_encode_node(
id<MTLComputePipelineState> pipeline = nil;
switch (src0->type) {
case GGML_TYPE_F32:
switch (r1ptg) {
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2].pipeline; break;
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3].pipeline; break;
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4].pipeline; break;
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5].pipeline; break;
default: GGML_ABORT("not implemented");
} break;
case GGML_TYPE_F16:
switch (r1ptg) {
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break;
@ -3560,7 +3627,7 @@ static int ggml_metal_encode_node(
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
@ -3878,38 +3945,6 @@ static int ggml_metal_encode_node(
default: break;
}
const int64_t neh10 = ne10; // n_embd
const int64_t neh11 = ne21; // n_tokens
const int64_t neh12 = ne02; // n_expert
const uint64_t nbh10 = ggml_type_size(GGML_TYPE_F16);
const uint64_t nbh11 = nbh10*neh10;
const uint64_t nbh12 = nbh11*neh11;
const uint64_t nbh13 = nbh12*neh12;
const size_t s_src1 = ggml_type_size(GGML_TYPE_F16)*neh10*neh11*neh12;
id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
if (!h_src1) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
return 0;
}
const int64_t neh0 = ne0;
const int64_t neh1 = ne21;
const int64_t neh2 = ne02;
const uint64_t nbh0 = ggml_type_size(GGML_TYPE_F32);
const uint64_t nbh1 = nbh0*neh0;
const uint64_t nbh2 = nbh1*neh1;
//const uint64_t nbh3 = nbh2*neh2;
const size_t s_dst = ggml_type_size(GGML_TYPE_F32)*neh0*neh1*neh2;
id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
if (!h_dst) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
return 0;
}
// tokens per expert
const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
@ -3919,8 +3954,8 @@ static int ggml_metal_encode_node(
}
// id map
// [n_expert_used, n_tokens]
const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne20*ne21;
// [n_tokens, n_expert]
const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne21*ne02;
id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
if (!h_ids) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
@ -3928,32 +3963,45 @@ static int ggml_metal_encode_node(
}
{
const int nth = MIN(1024, ne10/4);
ggml_metal_kargs_mul_mm_id_map0 args = {
ne02,
ne10,
ne11, // n_expert_used (bcast)
ne11, // n_expert_used (bcast)
nb11,
nb12,
neh11, // n_tokens
nbh11,
ne20, // n_expert_used
ne21, // n_tokens
ne20, // n_expert_used
nb21,
};
id<MTLComputePipelineState> pipeline = nil;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline;
pipeline = nil;
switch (ne20) {
case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1 ].pipeline; break;
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2 ].pipeline; break;
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4 ].pipeline; break;
case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6 ].pipeline; break;
case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8 ].pipeline; break;
case 16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16].pipeline; break;
default: GGML_ABORT("missing specialization for ne20 = %d", (int) ne20);
}
GGML_ASSERT(ne02 <= (int) pipeline.maxTotalThreadsPerThreadgroup);
const size_t smem = ne02*ne20*sizeof(uint16_t);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
[encoder setBuffer: h_src1 offset:0 atIndex:3];
[encoder setBuffer: h_tpe offset:0 atIndex:4];
[encoder setBuffer: h_ids offset:0 atIndex:5];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:1];
[encoder setBuffer: h_tpe offset:0 atIndex:2];
[encoder setBuffer: h_ids offset:0 atIndex:3];
[encoder setThreadgroupMemoryLength:smem atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)];
}
{
@ -3992,13 +4040,15 @@ static int ggml_metal_encode_node(
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.neh12 =*/ neh12,
/*.nbh10 =*/ nbh10,
/*.nbh11 =*/ nbh11,
/*.nbh12 =*/ nbh12,
/*.nbh13 =*/ nbh13,
/*.neh0 =*/ neh0,
/*.neh1 =*/ neh1,
/*.ne11 =*/ ne11, // n_expert_used (bcast)
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.ne20 =*/ ne20, // n_expert_used
/*.ne21 =*/ ne21, // n_tokens
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.r2 =*/ r2,
/*.r3 =*/ r3,
};
@ -4006,42 +4056,14 @@ static int ggml_metal_encode_node(
[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer: h_src1 offset:0 atIndex:2];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer: h_tpe offset:0 atIndex:3];
[encoder setBuffer: h_dst offset:0 atIndex:4];
[encoder setBuffer: h_ids offset:0 atIndex:4];
[encoder setBuffer:id_dst offset:offs_dst atIndex:5];
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
}
{
GGML_ASSERT(ne0 % 4 == 0);
const int nth = MIN(1024, ne0/4);
ggml_metal_kargs_mul_mm_id_map1 args = {
ne20, // n_expert_used
neh0,
neh1,
nbh1,
nbh2,
ne0,
nb1,
nb2,
};
id<MTLComputePipelineState> pipeline = nil;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer: h_dst offset:0 atIndex:1];
[encoder setBuffer: h_ids offset:0 atIndex:2];
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
[encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
}
} else {
id<MTLComputePipelineState> pipeline = nil;
@ -4701,7 +4723,6 @@ static int ggml_metal_encode_node(
} break;
case GGML_OP_IM2COL:
{
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
@ -5130,6 +5151,7 @@ static int ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline;
} else {
switch (ne00) {
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40 ].pipeline; break;
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
@ -5154,6 +5176,7 @@ static int ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline;
} else {
switch (ne00) {
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40 ].pipeline; break;
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
@ -5178,6 +5201,7 @@ static int ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline;
} else {
switch (ne00) {
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40 ].pipeline; break;
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
@ -5202,6 +5226,7 @@ static int ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline;
} else {
switch (ne00) {
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40 ].pipeline; break;
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
@ -5226,6 +5251,7 @@ static int ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline;
} else {
switch (ne00) {
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40 ].pipeline; break;
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
@ -5250,6 +5276,7 @@ static int ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline;
} else {
switch (ne00) {
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40 ].pipeline; break;
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
@ -5274,6 +5301,7 @@ static int ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline;
} else {
switch (ne00) {
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40 ].pipeline; break;
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
@ -5301,6 +5329,24 @@ static int ggml_metal_encode_node(
use_vec_kernel = true;
switch (ne00) {
case 40:
{
switch (src1->type) {
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40].pipeline; break;
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40].pipeline; break;
default:
{
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
GGML_LOG_ERROR("add template specialization for this type\n");
GGML_ABORT("add template specialization for this type");
}
}
} break;
case 64:
{
switch (src1->type) {
@ -5465,6 +5511,7 @@ static int ggml_metal_encode_node(
/*.nb33 =*/ nb33,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.scale =*/ scale,
/*.max_bias =*/ max_bias,
/*.m0 =*/ m0,
@ -5488,7 +5535,6 @@ static int ggml_metal_encode_node(
} else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
}
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
if (!use_vec_kernel) {
// half8x8 kernel
@ -5514,7 +5560,7 @@ static int ggml_metal_encode_node(
while (true) {
const size_t smem = FATTN_SMEM(nsgmax);
if (smem > device.maxThreadgroupMemoryLength) {
if (smem > device.maxThreadgroupMemoryLength/2) {
break;
}
nsgmax *= 2;
@ -5526,15 +5572,18 @@ static int ggml_metal_encode_node(
const size_t smem = FATTN_SMEM(nsg);
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
[encoder setThreadgroupMemoryLength:smem atIndex:0];
#undef FATTN_SMEM
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
#undef FATTN_SMEM
} else {
// half4x4 kernel
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
const int64_t nkpsg = 1*ncpsg; // TODO: make adjustable
GGML_ASSERT(nqptg <= 32);
GGML_ASSERT(nqptg % 1 == 0);
@ -5544,15 +5593,17 @@ static int ggml_metal_encode_node(
// for each query, we load it as f16 in shared memory (ne00)
// and store the soft_max values and the mask
//
// ne00*(nsg)
// ne20*(nsg)
// each simdgroup has a full f32 head vector in shared mem to accumulate results
//
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))
//#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)))*(sizeof(float)/2), 16))
int64_t nsgmax = 2;
while (true) {
const size_t smem = FATTN_SMEM(nsgmax);
if (smem > device.maxThreadgroupMemoryLength) {
// avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes
if (smem > device.maxThreadgroupMemoryLength/2) {
break;
}
nsgmax *= 2;
@ -5560,7 +5611,7 @@ static int ggml_metal_encode_node(
nsgmax /= 2;
// simdgroups per threadgroup (a.k.a. warps)
const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
int64_t nsg = 1;
while (nsg <= nsgt) {
@ -5568,13 +5619,74 @@ static int ggml_metal_encode_node(
}
nsg /= 2;
const size_t smem = FATTN_SMEM(nsg);
// workgroups
// each workgroup handles nsg*nkpsg cache values
uint16_t nwg = 1;
if (4*nsg*nkpsg >= ne11) {
const size_t smem = FATTN_SMEM(nsg);
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
[encoder setThreadgroupMemoryLength:smem atIndex:0];
//printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
// using 1 workgroup -> write the result directly into dst
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
[encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];
[encoder setThreadgroupMemoryLength:smem atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
} else {
nwg = 32;
nsg = MIN(4, nsg);
const size_t smem = FATTN_SMEM(nsg);
//printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
// sanity checks
GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
const int32_t nrows = ne1*ne2*ne3;
// temp buffer for writing the results from each workgroup
// - ne20: the size of the head vector
// - + 2: the S and M values for each intermediate result
const size_t s_tmp = ggml_type_size(GGML_TYPE_F32)*(nrows*nwg*(ne20 + 2));
id<MTLBuffer> h_tmp = ggml_metal_mem_pool_alloc(mem_pool, s_tmp);
if (!h_tmp) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tmp);
return 0;
}
//printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20);
//printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f);
[encoder setBuffer:h_tmp offset:0 atIndex:6];
[encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];
[encoder setThreadgroupMemoryLength:smem atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
// reduce the results from the workgroups
{
ggml_metal_kargs_flash_attn_ext_reduce args0 = {
nrows,
ne20,
};
id<MTLComputePipelineState> pipeline0 = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE].pipeline;
[encoder setComputePipelineState:pipeline0];
[encoder setBytes:&args0 length:sizeof(args0) atIndex:0];
[encoder setBuffer:h_tmp offset:0 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
//printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20);
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(32*32, 1, 1)];
}
}
#undef FATTN_SMEM
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
}
} break;
case GGML_OP_DUP:

View File

@ -68,6 +68,11 @@ void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg)
reg = (type4x4)(*src);
}
template <typename type4>
void dequantize_f32_t4(device const float4 * src, short il, thread type4 & reg) {
reg = (type4)(*src);
}
template <typename type4x4>
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
reg = (type4x4)(*src);
@ -974,9 +979,16 @@ kernel void kernel_mul(
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i10 = i0%args.ne10;
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
if (args.ne10 == 1) {
const float x = *((device float *)(src1_ptr));
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
}
} else {
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i10 = i0%args.ne10;
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
}
}
}
@ -1000,9 +1012,16 @@ kernel void kernel_div(
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i10 = i0%args.ne10;
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
if (args.ne10 == 1) {
const float x = 1.0f / *((device float *)(src1_ptr));
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
}
} else {
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i10 = i0%args.ne10;
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
}
}
}
@ -1964,14 +1983,15 @@ kernel void kernel_ssm_scan_f32(
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
const int64_t i = i0 + i1*nc;
const int64_t g = ir / (nh / ng); // repeat_interleave
float s0 = s0_buff[i];
float s = s_buff[i];
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
for (int64_t i2 = 0; i2 < n_t; ++i2) {
@ -2079,14 +2099,15 @@ kernel void kernel_ssm_scan_f32_group(
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
const int64_t i = i0 + i1*nc;
const int64_t g = ir / (nh / ng); // repeat_interleave
float s0 = s0_buff[i];
float s = s_buff[i];
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
for (int64_t i2 = 0; i2 < n_t; ++i2) {
@ -3001,7 +3022,6 @@ void kernel_mul_mv_ext_q4_f32_impl(
#pragma unroll(r1ptg)
for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]);
}
}
@ -3186,6 +3206,11 @@ kernel void kernel_mul_mv_ext_q4x4_f32_disp(
typedef decltype(kernel_mul_mv_ext_q4_f32_disp <2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t;
typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>) mul_mv_ext_q4x4_f32_t;
template [[host_name("kernel_mul_mv_ext_f32_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, float4, 4, dequantize_f32_t4>;
template [[host_name("kernel_mul_mv_ext_f32_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, float4, 4, dequantize_f32_t4>;
template [[host_name("kernel_mul_mv_ext_f32_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, float4, 4, dequantize_f32_t4>;
template [[host_name("kernel_mul_mv_ext_f32_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, float4, 4, dequantize_f32_t4>;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4, 4, dequantize_f16_t4>;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4, 4, dequantize_f16_t4>;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>;
@ -4663,6 +4688,7 @@ kernel void kernel_flash_attn_ext(
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
template [[host_name("kernel_flash_attn_ext_f16_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96, 96>;
@ -4674,6 +4700,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
#if defined(GGML_METAL_USE_BF16)
template [[host_name("kernel_flash_attn_ext_bf16_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
@ -4685,6 +4712,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
#endif
template [[host_name("kernel_flash_attn_ext_q4_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96, 96>;
@ -4695,6 +4723,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q4_1_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96, 96>;
@ -4705,6 +4734,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q5_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96, 96>;
@ -4715,6 +4745,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q5_1_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96, 96>;
@ -4725,6 +4756,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q8_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96, 96>;
@ -4765,14 +4797,16 @@ kernel void kernel_flash_attn_ext_vec(
device const char * mask,
device const char * sinks,
device char * dst,
constant uint16_t & nwg,
threadgroup half * shmem_f16 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 ntg[[threads_per_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const short nsg = ntg.y; // number of simdgroups
const short iwg = tgpig[2]%nwg;
const int iq3 = tgpig[2];
const int iq3 = tgpig[2]/nwg;
const int iq2 = tgpig[1];
const int iq1 = tgpig[0];
@ -4851,7 +4885,7 @@ kernel void kernel_flash_attn_ext_vec(
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
for (int ic0 = (int) iwg*C*nsg; ic0 < args.ne11; ic0 += (int) nwg*C*nsg) {
const int ic = ic0 + C*sgitg;
if (ic >= args.ne11) {
break;
@ -4981,7 +5015,7 @@ kernel void kernel_flash_attn_ext_vec(
}
}
if (sinks != q && sgitg == 0) {
if (sinks != q && sgitg == 0 && iwg == 0) {
const float m = M;
const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
@ -5090,14 +5124,25 @@ kernel void kernel_flash_attn_ext_vec(
threadgroup_barrier(mem_flags::mem_threadgroup);
}
device float4 * dst4 = (device float4 *) dst;
// final rescale with 1/S and store to global memory
if (sgitg == 0) {
const float S = ss[0];
const int64_t nrows = args.ne3*args.ne2*args.ne1;
const int64_t rid = iq3*args.ne2*args.ne1 + iq2 + iq1*args.ne1;
device float4 * dst4 = (device float4 *) dst;
device float * dst1 = (device float *) dst + nrows*DV*nwg; // the S and M are stored after the results
const float S = nwg == 1 ? 1.0f/ss[0] : 1.0f;
// interleave the workgroup data
for (short i = tiisg; i < DV4; i += NW) {
dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*DV4 + i] = (float4) sr4[i]/S;
dst4[rid*DV4*nwg + nwg*i + iwg] = (float4) sr4[i]*S;
}
// store S and M
if (nwg > 1 && tiisg == 0) {
dst1[rid*(2*nwg) + 2*iwg + 0] = ss[0];
dst1[rid*(2*nwg) + 2*iwg + 1] = ss[1];
}
}
}
@ -5115,6 +5160,16 @@ kernel void kernel_flash_attn_ext_vec(
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
template [[host_name("kernel_flash_attn_ext_vec_f16_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 40, 40, 8>;
#if defined(GGML_METAL_USE_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 40, 40, 8>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 40, 40, 8>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 40, 40, 8>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 40, 40, 8>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 40, 40, 8>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 40, 40, 8>;
template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 8>;
#if defined(GGML_METAL_USE_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 8>;
@ -5187,6 +5242,41 @@ template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flas
#undef FA_TYPES
kernel void kernel_flash_attn_ext_reduce(
constant ggml_metal_kargs_flash_attn_ext_reduce & args,
device const char * htmp,
device char * dst,
uint tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const uint64_t rid = tgpig;
const short nwg = 32;
const short iwg = tiisg;
const short DV = args.ne20;
const short DV4 = DV/4;
device const float4 * htmp4 = (device const float4 *) htmp + rid*DV4*nwg;
device const float * ss = (device const float *) htmp + (uint64_t)args.nrows*DV*nwg;
device float4 * dst4 = (device float4 *) dst + rid*DV4;
float S = ss[rid*(2*nwg) + 2*iwg + 0];
float M = ss[rid*(2*nwg) + 2*iwg + 1];
const float m = simd_max(M);
const float ms = exp(M - m);
S = 1.0f/simd_sum(S*ms);
for (int i = sgitg; i < DV4; i += nwg) {
const float4 v = simd_sum(htmp4[i*nwg + iwg]*ms);
if (iwg == 0) {
dst4[i] = v*S;
}
}
}
template<typename T>
kernel void kernel_set(
constant ggml_metal_kargs_set & args,
@ -7474,97 +7564,81 @@ kernel void kernel_mul_mm(
}
}
template<typename T4>
template<short ne20> // n_expert_used
kernel void kernel_mul_mm_id_map0(
constant ggml_metal_kargs_mul_mm_id_map0 & args,
device const char * src1,
device const char * src2,
device char * hsrc1,
device char * htpe,
device char * hids,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int ide = tgpig[0]; // expert id
threadgroup char * shmem [[threadgroup(0)]],
ushort tpitg[[thread_position_in_threadgroup]],
ushort ntg[[threads_per_threadgroup]]) {
const short ide = tpitg; // expert id
int n_all = 0;
uint32_t n_all = 0;
device int32_t * ids_i32 = (device int32_t *) (hids);
device int32_t * ids_i32 = (device int32_t *) hids + ide*args.ne21;
for (int i21 = 0; i21 < args.neh11; i21++) { // n_tokens
device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21);
for (int i21 = 0; i21 < args.ne21; i21 += ntg) { // n_tokens
if (i21 + tpitg < args.ne21) {
device const int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21);
for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used
if (src2_i32[i20] != ide) {
continue;
threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*ne20;
#pragma unroll(ne20)
for (short i20 = 0; i20 < ne20; i20++) {
sids[i20] = src2_i32[i20];
}
device const float4 * src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11);
device T4 * hsrc1_f32x4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11);
for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) {
hsrc1_f32x4[i00] = (T4) (src1_f32x4[i00]);
}
if (tpitg.x == 0) {
ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all;
}
++n_all;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (short t = 0; t < ntg; t++) {
if (i21 + t >= args.ne21) {
break;
}
threadgroup const uint16_t * sids = (threadgroup const uint16_t *) shmem + t*ne20;
short sel = 0;
#pragma unroll(ne20)
for (short i20 = 0; i20 < ne20; i20++) {
sel += (sids[i20] == ide)*(i20 + 1);
}
ids_i32[n_all] = (i21 + t)*ne20 + sel - 1;
n_all += sel > 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
if (tpitg.x == 0) {
device int32_t * tpe_i32 = (device int32_t *) (htpe);
tpe_i32[ide] = n_all;
}
device uint32_t * tpe_u32 = (device uint32_t *) (htpe);
tpe_u32[ide] = n_all;
}
typedef decltype(kernel_mul_mm_id_map0<half4>) kernel_mul_mm_id_map0_t;
typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;
template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<half4>;
template<typename T>
kernel void kernel_mul_mm_id_map1(
constant ggml_metal_kargs_mul_mm_id_map1 & args,
device const char * hdst,
device const char * hids,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i20 = tgpig[0]; // used expert
const int i21 = tgpig[1]; // token
device const int32_t * ids_i32 = (device const int32_t *) (hids);
device float4 * dst_f32x4 = (device float4 *) (dst + i20*args.nb1 + i21*args.nb2);
const int id = ids_i32[i21*args.ne20 + i20];
const int ide = id / args.neh1;
const int idt = id % args.neh1;
device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2);
for (int64_t i0 = tpitg.x; i0 < args.neh0/4; i0 += ntg.x) {
dst_f32x4[i0] = hdst_f32x4[i0];
}
}
typedef decltype(kernel_mul_mm_id_map1<float>) kernel_mul_mm_id_map1_t;
template [[host_name("kernel_mul_mm_id_map1_f32")]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1<float>;
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
kernel void kernel_mul_mm_id(
constant ggml_metal_kargs_mul_mm_id & args,
device const char * src0,
device const char * src1,
device const char * tpe,
device const char * htpe,
device const char * hids,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
threadgroup T * sa = (threadgroup T *)(shmem);
@ -7572,19 +7646,20 @@ kernel void kernel_mul_mm_id(
const int r0 = tgpig.y;
const int r1 = tgpig.x;
const int im = tgpig.z;
const int im = tgpig.z; // expert
device const int32_t * tpe_i32 = (device const int32_t *) (tpe);
device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
device const int32_t * ids_i32 = (device const int32_t *) (hids);
const int neh1 = tpe_i32[im];
const int32_t neh1 = tpe_u32[im];
if (r1*BLOCK_SIZE_N >= neh1) {
return;
}
// if this block is of 64x32 shape or smaller
const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
// a thread shouldn't load data outside of the matrix
const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
@ -7600,20 +7675,23 @@ kernel void kernel_mul_mm_id(
short il = (tiitg % THREAD_PER_ROW);
const int i12 = im%args.neh12;
const int i13 = im/args.neh12;
const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + thread_col];
const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const short i11 = (id % args.ne20) % args.ne11;
const short i12 = (id / args.ne20);
const short i13 = 0;
const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
const short offset1 = il/nl;
device const block_q * x = (device const block_q *)(src0
+ args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
device const half * y = (device const half *)(src1
+ args.nbh13*i13
+ args.nbh12*i12
+ args.nbh11*(r1*BLOCK_SIZE_N + thread_col)
+ args.nbh10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
device const float * y = (device const float *)(src1
+ args.nb13*i13
+ args.nb12*i12
+ args.nb11*i11
+ args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
// load data and store to threadgroup memory
@ -7629,7 +7707,7 @@ kernel void kernel_mul_mm_id(
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
}
*(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device half2x4 *) y);
*(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (half2x4)(*((device float2x4 *) y));
il = (il + 2 < nl) ? il + 2 : il % 2;
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
@ -7665,43 +7743,38 @@ kernel void kernel_mul_mm_id(
}
}
if ((r0 + 1) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1) * BLOCK_SIZE_N <= neh1) {
device float * C = (device float *) dst +
(BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \
(BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.neh0 + im*args.neh1*args.neh0;
threadgroup_barrier(mem_flags::mem_threadgroup);
for (short i = 0; i < 8; i++) {
simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.neh0 * (i/4), args.neh0);
}
} else {
// block is smaller than 64x32, we should avoid writing data outside of the matrix
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup float * temp_str = ((threadgroup float *) shmem) \
+ 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
for (short i = 0; i < 8; i++) {
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
threadgroup float * temp_str = ((threadgroup float *) shmem) \
+ 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
#pragma unroll(8)
for (short i = 0; i < 8; i++) {
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (short j = sgitg; j < n_cols; j += 4) {
const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j];
const short ide = id % args.ne20;
const short idt = id / args.ne20;
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1*args.ne0;
device float4 * D4 = (device float4 *) D;
threadgroup float * C = (threadgroup float *) shmem + (j*BLOCK_SIZE_M);
threadgroup float4 * C4 = (threadgroup float4 *) C;
int i = tiisg;
for (; i < n_rows/4; i += 32) {
*(D4 + i) = *(C4 + i);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sgitg == 0) {
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1*args.neh0;
device float4 * D4 = (device float4 *) D;
threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
threadgroup float4 * C4 = (threadgroup float4 *) C;
int i = 0;
for (; i < n_rows/4; i++) {
*(D4 + i) = *(C4 + i);
}
i *= 4;
for (; i < n_rows; i++) {
*(D + i) = *(C + i);
}
}
i = (4*(n_rows/4)) + tiisg;
for (; i < n_rows; i += 32) {
*(D + i) = *(C + i);
}
}
}

View File

@ -420,9 +420,9 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_clamp;
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick,
kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
cl_kernel kernel_norm;
cl_kernel kernel_norm, kernel_norm_mul_add;
cl_kernel kernel_rms_norm, kernel_rms_norm_mul;
cl_kernel kernel_group_norm;
cl_kernel kernel_group_norm, kernel_group_norm_mul_add;
cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
cl_kernel kernel_soft_max, kernel_soft_max_4;
cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
@ -1161,7 +1161,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
backend_ctx->program_norm =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program_norm, "kernel_norm", &err), err));
CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program_norm, "kernel_norm", &err), err));
CL_CHECK((backend_ctx->kernel_norm_mul_add = clCreateKernel(backend_ctx->program_norm, "kernel_norm_mul_add", &err), err));
GGML_LOG_CONT(".");
}
@ -1487,7 +1488,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
backend_ctx->program_group_norm =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_group_norm = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm", &err), err));
CL_CHECK((backend_ctx->kernel_group_norm = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm", &err), err));
CL_CHECK((backend_ctx->kernel_group_norm_mul_add = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm_mul_add", &err), err));
GGML_LOG_CONT(".");
}
@ -2498,12 +2500,47 @@ static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
return false;
}
} else if (ops.size() == 3 && ops.begin()[0] == GGML_OP_NORM && ops.begin()[1] == GGML_OP_MUL && ops.begin()[2] == GGML_OP_ADD) {
const ggml_tensor *norm = cgraph->nodes[node_idx];
const ggml_tensor *mul = cgraph->nodes[node_idx+1];
const ggml_tensor *add = cgraph->nodes[node_idx+2];
const ggml_tensor *w = mul->src[0] == norm ? mul->src[1] : mul->src[0];
const ggml_tensor *b = add->src[0] == mul ? add->src[1] : add->src[0];
// norm fusion only supports F32
if (norm->src[0]->type != GGML_TYPE_F32 || w->type != GGML_TYPE_F32 || b->type != GGML_TYPE_F32) {
return false;
}
if (norm->src[0]->ne[0] % 4 != 0) {
return false;
}
if (!ggml_is_contiguous(norm->src[0]) || !ggml_is_contiguous(w) || !ggml_is_contiguous(b)) {
return false;
}
} else if (ops.size() == 3 && ops.begin()[0] == GGML_OP_GROUP_NORM && ops.begin()[1] == GGML_OP_MUL && ops.begin()[2] == GGML_OP_ADD) {
const ggml_tensor *gn = cgraph->nodes[node_idx];
const ggml_tensor *mul = cgraph->nodes[node_idx+1];
const ggml_tensor *add = cgraph->nodes[node_idx+2];
const ggml_tensor *w = mul->src[0] == gn ? mul->src[1] : mul->src[0];
const ggml_tensor *b = add->src[0] == mul ? add->src[1] : add->src[0];
if (gn->src[0]->type != GGML_TYPE_F32 || w->type != GGML_TYPE_F32 || b->type != GGML_TYPE_F32) {
return false;
}
if (!ggml_is_contiguous(gn->src[0]) || !ggml_is_contiguous(w) || !ggml_is_contiguous(b)) {
return false;
}
}
return true;
}
static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor);
static void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor);
static void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor);
static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
@ -2520,6 +2557,16 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
continue;
}
if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_NORM, GGML_OP_MUL, GGML_OP_ADD })) {
ggml_opencl_op_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
i += 2;
continue;
}
if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_GROUP_NORM, GGML_OP_MUL, GGML_OP_ADD })) {
ggml_opencl_op_group_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
i += 2;
continue;
}
if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ggml_opencl_op_rms_norm_fused(backend, node, cgraph->nodes[i+1]);
i++;
@ -2647,8 +2694,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SOFT_MAX:
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
return true;
case GGML_OP_RMS_NORM:
return op->ne[0] % 4 == 0 && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_REPEAT:
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
case GGML_OP_PAD:
@ -5038,6 +5086,140 @@ static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor *
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
static void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) {
GGML_ASSERT(norm_tensor && mul_tensor && add_tensor);
const ggml_tensor * src0 = norm_tensor->src[0];
const ggml_tensor * src1 = mul_tensor->src[0] == norm_tensor ? mul_tensor->src[1] : mul_tensor->src[0];
const ggml_tensor * src2 = add_tensor->src[0] == mul_tensor ? add_tensor->src[1] : add_tensor->src[0];
const ggml_tensor * dst = add_tensor;
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offset1 = extra1->offset + src1->view_offs;
cl_ulong offset2 = extra2->offset + src2->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
float eps;
memcpy(&eps, norm_tensor->op_params, sizeof(float));
const int ne00 = src0->ne[0], ne01 = src0->ne[1], ne02 = src0->ne[2], ne03 = src0->ne[3];
const cl_ulong nb01 = src0->nb[1], nb02 = src0->nb[2], nb03 = src0->nb[3];
const int ne10 = src1->ne[0], ne11 = src1->ne[1], ne12 = src1->ne[2], ne13 = src1->ne[3];
const cl_ulong nb11 = src1->nb[1], nb12 = src1->nb[2], nb13 = src1->nb[3];
const int ne20 = src2->ne[0], ne21 = src2->ne[1], ne22 = src2->ne[2], ne23 = src2->ne[3];
const cl_ulong nb21 = src2->nb[1], nb22 = src2->nb[2], nb23 = src2->nb[3];
const cl_ulong nbd1 = dst->nb[1], nbd2 = dst->nb[2], nbd3 = dst->nb[3];
size_t sgs;
if (backend_ctx->gpu_family == ADRENO) sgs = 64;
else if (backend_ctx->gpu_family == INTEL) sgs = 32;
else GGML_ASSERT(false && "Unsupported GPU");
cl_kernel kernel = backend_ctx->kernel_norm_mul_add;
int nth = sgs;
int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
while (nth < ne00/4 && nth < max_workgroup_size) nth *= 2;
nth = MIN(nth, max_workgroup_size);
nth = MIN(nth, ne00/4);
size_t gws[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
size_t lws[] = {(size_t)nth, 1, 1};
size_t num_subgroups = (nth + sgs - 1) / sgs;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne03));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne10));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne11));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne13));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne20));
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne21));
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne22));
CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne23));
CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb21));
CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb22));
CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb23));
CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nbd1));
CL_CHECK(clSetKernelArg(kernel, 30, sizeof(cl_ulong), &nbd2));
CL_CHECK(clSetKernelArg(kernel, 31, sizeof(cl_ulong), &nbd3));
CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float), &eps));
CL_CHECK(clSetKernelArg(kernel, 33, sizeof(cl_float2) * num_subgroups, NULL));
backend_ctx->enqueue_ndrange_kernel(kernel, 3, gws, lws, dst);
}
static void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) {
GGML_ASSERT(gn_tensor && mul_tensor && add_tensor);
const ggml_tensor * src0 = gn_tensor->src[0];
const ggml_tensor * src1 = mul_tensor->src[0] == gn_tensor ? mul_tensor->src[1] : mul_tensor->src[0];
const ggml_tensor * src2 = add_tensor->src[0] == mul_tensor ? add_tensor->src[1] : add_tensor->src[0];
const ggml_tensor * dst = add_tensor;
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offset1 = extra1->offset + src1->view_offs;
cl_ulong offset2 = extra2->offset + src2->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
int groups;
float eps;
memcpy(&groups, gn_tensor->op_params, sizeof(int));
memcpy(&eps, (char *)gn_tensor->op_params + sizeof(int), sizeof(float));
cl_kernel kernel = backend_ctx->kernel_group_norm_mul_add;
int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
int ne = ggml_nelements(src0);
int group_size = ne / groups;
size_t lws[] = { (size_t)MIN(max_workgroup_size, group_size) };
size_t gws[] = { (size_t)groups * lws[0] };
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &group_size));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float), &eps));
backend_ctx->enqueue_ndrange_kernel(kernel, 1, gws, lws, dst);
}
static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);

View File

@ -70,3 +70,52 @@ kernel void kernel_group_norm(
dst[j] *= scale;
}
}
//------------------------------------------------------------------------------
// group_norm_mul_add
//------------------------------------------------------------------------------
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_32
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_group_norm_mul_add(
global float * src0, ulong offset0,
global float * src1, ulong offset1,
global float * src2, ulong offset2,
global float * dst, ulong offsetd,
int ne,
int group_size,
float eps
) {
src0 = (global float *)((global char *)src0 + offset0);
src1 = (global float *)((global char *)src1 + offset1);
src2 = (global float *)((global char *)src2 + offset2);
dst = (global float *)((global char *)dst + offsetd);
int start = get_group_id(0) * group_size;
int end = start + group_size;
if (end > ne) {
end = ne;
}
float sum = 0.0f;
float sum_sq = 0.0f;
for (int j = start + get_local_id(0); j < end; j += get_local_size(0)) {
float val = src0[j];
sum += val;
sum_sq += val*val;
}
sum = sub_group_reduce_add(sum);
sum_sq = sub_group_reduce_add(sum_sq);
const float mean = sum / group_size;
const float var = sum_sq / group_size - mean * mean;
const float scale = rsqrt(var + eps);
for (int j = start + get_local_id(0); j < end; j += get_local_size(0)) {
dst[j] = ((src0[j] - mean) * scale) * src1[j] + src2[j];
}
}

View File

@ -79,3 +79,83 @@ kernel void kernel_norm(
y[i00] = y[i00] * scale;
}
}
//------------------------------------------------------------------------------
// norm_mul_add
//------------------------------------------------------------------------------
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_32
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_norm_mul_add(
global char * src0_ptr, ulong src0_offset,
global char * src1_ptr, ulong src1_offset,
global char * src2_ptr, ulong src2_offset,
global char * dst_ptr, ulong dst_offset,
int ne00, int ne01, int ne02, int ne03,
ulong nb01, ulong nb02, ulong nb03,
int ne10, int ne11, int ne12, int ne13,
ulong nb11, ulong nb12, ulong nb13,
int ne20, int ne21, int ne22, int ne23,
ulong nb21, ulong nb22, ulong nb23,
ulong nbd1, ulong nbd2, ulong nbd3,
float eps,
local float2 * sums
) {
const int i03 = get_group_id(2);
const int i02 = get_group_id(1);
const int i01 = get_group_id(0);
global float4 * x = (global float4 *)(src0_ptr + src0_offset + i01*nb01 + i02*nb02 + i03*nb03);
global float4 * w = (global float4 *)(src1_ptr + src1_offset + (i01%ne11)*nb11 + (i02%ne12)*nb12 + (i03%ne13)*nb13);
global float4 * b = (global float4 *)(src2_ptr + src2_offset + (i01%ne21)*nb21 + (i02%ne22)*nb22 + (i03%ne23)*nb23);
global float4 * y = (global float4 *)(dst_ptr + dst_offset + i01*nbd1 + i02*nbd2 + i03*nbd3);
float p_sum = 0.0f;
float p_sum_sq = 0.0f;
const int n_chunks = ne00 / 4;
for (int i00 = get_local_id(0); i00 < n_chunks; i00 += get_local_size(0)) {
float4 val = x[i00];
p_sum += val.x + val.y + val.z + val.w;
p_sum_sq += dot(val, val);
}
p_sum = sub_group_reduce_add(p_sum);
p_sum_sq = sub_group_reduce_add(p_sum_sq);
if (get_sub_group_local_id() == 0) {
sums[get_sub_group_id()] = (float2)(p_sum, p_sum_sq);
}
barrier(CLK_LOCAL_MEM_FENCE);
if (get_local_id(0) == 0) {
float sum = 0.0f;
float sum_sq = 0.0f;
for (uint i = 0; i < get_num_sub_groups(); ++i) {
float2 s = sums[i];
sum += s.x;
sum_sq += s.y;
}
const float inv_ne00 = 1.0f / (float)ne00;
const float mean = sum * inv_ne00;
const float variance = mad(-mean, mean, sum_sq * inv_ne00);
sums[0] = (float2)(mean, rsqrt(variance + eps));
}
barrier(CLK_LOCAL_MEM_FENCE);
const float2 mean_scale = sums[0];
const float mean = mean_scale.x;
const float scale = mean_scale.y;
const float neg_mean_scale = -mean * scale;
for (int i00 = get_local_id(0); i00 < n_chunks; i00 += get_local_size(0)) {
const int w_idx = ne10 > 1 ? i00 : 0;
const int b_idx = ne20 > 1 ? i00 : 0;
const float4 norm_x = mad(x[i00], (float4)scale, (float4)neg_mean_scale);
y[i00] = mad(norm_x, w[w_idx], b[b_idx]);
}
}

View File

@ -4364,11 +4364,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
#endif
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
return true;
case GGML_OP_L2_NORM:
case GGML_OP_GROUP_NORM:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_RMS_NORM:
return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
case GGML_OP_SCALE:
return true;
case GGML_OP_CONT:

View File

@ -2090,10 +2090,11 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
const uint32_t warps = warptile[0] / warptile[10];
const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0;
const uint32_t mmid_row_ids = mul_mat_id ? (warptile[2] * 2 * sizeof(uint16_t)) : 0;
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
const uint32_t ballots_sh = mul_mat_id ? (warps * 4 * sizeof(uint32_t)) : 0;
const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size + ballots_sh;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), "
@ -2183,7 +2184,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
const uint32_t mul_mat_subgroup_size_32 = std::max(mul_mat_subgroup_size, 32u);
const bool subgroup_min_size_16 = (!device->subgroup_size_control && device->subgroup_size >= 16) ||
(device->subgroup_size_control && device->subgroup_min_size <= 16 && device->subgroup_max_size >= 16);
(device->subgroup_size_control && device->subgroup_max_size >= 16);
// mulmat
std::vector<uint32_t> l_warptile, m_warptile, s_warptile,
@ -5799,11 +5800,6 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
ggml_vk_sync_buffers(ctx, subctx);
}
}
if (y_non_contig || quantize_y) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
}
if (x_non_contig) {
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
@ -5815,6 +5811,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
if (y_non_contig) {
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
@ -5823,6 +5822,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
if (quantize_y) {
if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13);
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
ctx->prealloc_y_last_tensor_used = src1;
@ -6007,11 +6009,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
ggml_vk_sync_buffers(ctx, subctx);
}
}
if (y_non_contig) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
}
if (x_non_contig) {
GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
@ -6021,6 +6018,9 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
@ -6288,7 +6288,6 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
const uint64_t nei0 = ids->ne[0];
const uint64_t nei1 = ids->ne[1];
GGML_ASSERT(nei0 * nei1 <= 4096);
const uint32_t nbi1 = ids->nb[1];
const uint32_t nbi2 = ids->nb[2];
@ -6454,11 +6453,6 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
ggml_vk_sync_buffers(ctx, subctx);
}
}
if (y_non_contig) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
}
if (x_non_contig) {
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
@ -6471,6 +6465,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (y_non_contig) {
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
@ -6668,11 +6665,6 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
ggml_vk_sync_buffers(ctx, subctx);
}
}
if (y_non_contig) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
}
if (x_non_contig) {
GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
@ -6682,6 +6674,9 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
@ -6728,37 +6723,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
} else {
// Split based on number of ids, to fit in shared memory
const uint32_t nei0 = (uint32_t)src2->ne[0];
const uint32_t nei1 = (uint32_t)src2->ne[1];
GGML_ASSERT(nei0 <= 4096);
const uint32_t split_size = std::min(nei1, 4096u / nei0);
if (split_size == nei1) {
ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
} else {
ggml_tensor src1_copy = *src1;
ggml_tensor src2_copy = *src2;
ggml_tensor dst_copy = *dst;
for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) {
const uint32_t n_tokens = std::min(split_size, nei1 - token_start);
src1_copy.view_offs = src1->view_offs + token_start * src1_copy.nb[2];
src2_copy.view_offs = src2->view_offs + token_start * src2_copy.nb[1];
dst_copy.view_offs = dst->view_offs + token_start * dst_copy.nb[2];
src1_copy.ne[2] = n_tokens;
src2_copy.ne[1] = n_tokens;
dst_copy.ne[2] = n_tokens;
ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, &src1_copy, &src2_copy, &dst_copy, dryrun);
// invalidate cached prealloc_y, can't cache based on the copy of the ggml_tensor
ctx->prealloc_y_last_pipeline_used = {};
ctx->prealloc_y_last_tensor_used = nullptr;
}
}
ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
}
}

View File

@ -109,13 +109,13 @@ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
#define NUM_WARPS (BLOCK_SIZE / WARP)
#ifdef MUL_MAT_ID
shared u16vec2 row_ids[4096];
shared u16vec2 row_ids[BN];
uint _ne1;
#ifdef MUL_MAT_ID_USE_SUBGROUPS
shared uvec4 ballots_sh[NUM_WARPS];
void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
_ne1 = 0;
uint num_elements = p.nei1 * p.nei0;
uint nei0shift = findLSB(p.nei0);
@ -165,11 +165,14 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
barrier();
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
if (in_range && id == expert_idx) {
row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
}
_ne1 += total;
iter &= 15;
if (_ne1 >= (ic + 1) * BN) {
break;
}
}
barrier();
}
@ -242,16 +245,18 @@ void main() {
#ifdef MUL_MAT_ID
#ifdef MUL_MAT_ID_USE_SUBGROUPS
if (bitCount(p.nei0) == 1) {
load_row_ids(expert_idx, true);
load_row_ids(expert_idx, true, ic);
} else {
load_row_ids(expert_idx, false);
load_row_ids(expert_idx, false, ic);
}
#else
_ne1 = 0;
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
row_ids[_ne1] = u16vec2(ii0, ii1);
if (_ne1 >= ic * BN) {
row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
}
_ne1++;
}
}
@ -797,7 +802,7 @@ void main() {
[[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
#if LOAD_VEC_B == 8
#ifdef MUL_MAT_ID
const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
const u16vec2 row_idx = row_ids[loadc_b + l];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
#else
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
@ -813,7 +818,7 @@ void main() {
buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w);
#elif LOAD_VEC_B == 4
#ifdef MUL_MAT_ID
const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
const u16vec2 row_idx = row_ids[loadc_b + l];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
#else
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
@ -832,7 +837,7 @@ void main() {
#else
const uint row_i = ic * BN + loadc_b + l;
if (row_i < _ne1 && block + loadr_b < end_k) {
const u16vec2 row_idx = row_ids[row_i];
const u16vec2 row_idx = row_ids[loadc_b + l];
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
} else {
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
@ -903,7 +908,7 @@ void main() {
const uint row_i = dc + cm_col * TN + col + store_c;
if (row_i >= _ne1) break;
const u16vec2 row_idx = row_ids[row_i];
const u16vec2 row_idx = row_ids[row_i - ic * BN];
if (dr + cm_row * TM + store_r < p.M) {
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
@ -953,7 +958,7 @@ void main() {
const uint row_i = dc_warp + cc;
if (row_i >= _ne1) break;
const u16vec2 row_idx = row_ids[row_i];
const u16vec2 row_idx = row_ids[row_i - ic * BN];
#endif // MUL_MAT_ID
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
#ifdef MUL_MAT_ID

View File

@ -93,7 +93,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};
shared u16vec4 row_ids[4096];
shared u16vec4 row_ids[BN];
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
B_TYPE b[];
@ -111,7 +111,7 @@ B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const i
return B_TYPE(0.0);
}
const u16vec4 row_idx = row_ids[row_i];
const u16vec4 row_idx = row_ids[row_i & (BN - 1)];
B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]];
return ret;
@ -123,14 +123,14 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem
uint dc = ic * BN + c;
if (dr < p.M && dc < _ne1) {
uint row_i = dc;
uint row_i = c;
const u16vec4 row_idx = row_ids[row_i];
data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem;
}
return elem;
}
void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
_ne1 = 0;
uint num_elements = p.nei1 * p.nei0;
uint nei0shift = findLSB(p.nei0);
@ -180,11 +180,14 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
barrier();
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
if (in_range && id == expert_idx) {
row_ids[_ne1 + idx] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0);
if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
row_ids[_ne1 + idx - ic * BN] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0);
}
_ne1 += total;
iter &= 15;
if (_ne1 >= (ic + 1) * BN) {
break;
}
}
barrier();
}
@ -218,9 +221,9 @@ void main() {
#ifdef MUL_MAT_ID
if (bitCount(p.nei0) == 1) {
load_row_ids(expert_idx, true);
load_row_ids(expert_idx, true, ic);
} else {
load_row_ids(expert_idx, false);
load_row_ids(expert_idx, false, ic);
}
// Workgroup has no work

View File

@ -231,8 +231,10 @@ class Keys:
MIDDLE_ID = "tokenizer.ggml.middle_token_id"
class Adapter:
TYPE = "adapter.type"
LORA_ALPHA = "adapter.lora.alpha"
TYPE = "adapter.type"
LORA_ALPHA = "adapter.lora.alpha"
LORA_TASK_NAME = "adapter.lora.task_name"
LORA_PROMPT_PREFIX = "adapter.lora.prompt_prefix"
class IMatrix:
CHUNK_COUNT = "imatrix.chunk_count"
@ -315,6 +317,7 @@ class MODEL_ARCH(IntEnum):
NOMIC_BERT_MOE = auto()
NEO_BERT = auto()
JINA_BERT_V2 = auto()
JINA_BERT_V3 = auto()
BLOOM = auto()
STABLELM = auto()
QWEN = auto()
@ -364,6 +367,7 @@ class MODEL_ARCH(IntEnum):
T5ENCODER = auto()
JAIS = auto()
NEMOTRON = auto()
NEMOTRON_H = auto()
EXAONE = auto()
EXAONE4 = auto()
GRANITE = auto()
@ -647,6 +651,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe",
MODEL_ARCH.NEO_BERT: "neo-bert",
MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2",
MODEL_ARCH.JINA_BERT_V3: "jina-bert-v3",
MODEL_ARCH.BLOOM: "bloom",
MODEL_ARCH.STABLELM: "stablelm",
MODEL_ARCH.QWEN: "qwen",
@ -696,6 +701,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.T5ENCODER: "t5encoder",
MODEL_ARCH.JAIS: "jais",
MODEL_ARCH.NEMOTRON: "nemotron",
MODEL_ARCH.NEMOTRON_H: "nemotron_h",
MODEL_ARCH.EXAONE: "exaone",
MODEL_ARCH.EXAONE4: "exaone4",
MODEL_ARCH.GRANITE: "granite",
@ -1234,6 +1240,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.LAYER_OUT_NORM,
MODEL_TENSOR.CLS,
],
MODEL_ARCH.JINA_BERT_V3: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
MODEL_TENSOR.TOKEN_TYPES,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_OUT_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.LAYER_OUT_NORM,
],
MODEL_ARCH.MPT: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
@ -2281,6 +2299,25 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.NEMOTRON_H: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.SSM_IN,
MODEL_TENSOR.SSM_CONV1D,
MODEL_TENSOR.SSM_DT,
MODEL_TENSOR.SSM_A,
MODEL_TENSOR.SSM_D,
MODEL_TENSOR.SSM_NORM,
MODEL_TENSOR.SSM_OUT,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.EXAONE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
@ -2850,6 +2887,7 @@ class VisionProjectorType:
QWEN25O = "qwen2.5o" # omni
VOXTRAL = "voxtral"
LFM2 = "lfm2"
KIMIVL = "kimivl"
# Items here are (block size, type size)

View File

@ -19,6 +19,61 @@ import gguf
logger = logging.getLogger("gguf-convert-endian")
def byteswap_q4_0(tensor, block_offs):
# Each block_q4_0 consists of an f16 delta (scaling factor) followed by 16 int8 quantizations.
# Byte-Swap f16 sized delta field
delta = tensor.data[block_offs:block_offs + 2].view(dtype=np.uint16)
delta.byteswap(inplace=True)
def byteswap_q8_0(tensor, block_offs):
# Each block_q8_0 consists of an f16 delta (scaling factor) followed by 32 int8 quantizations.
# Byte-Swap f16 sized delta field
delta = tensor.data[block_offs:block_offs + 2].view(dtype=np.uint16)
delta.byteswap(inplace=True)
def byteswap_q4_k(tensor, block_offs):
# Each block_q4_k consists of 2 f16 values followed by 140 int8 values.
# Byte-Swap f16 sized fields
delta = tensor.data[block_offs:block_offs + 2].view(dtype=np.uint16)
delta.byteswap(inplace=True)
delta = tensor.data[block_offs + 2:block_offs + 4].view(dtype=np.uint16)
delta.byteswap(inplace=True)
def byteswap_q6_k(tensor, block_offs):
# Each block_q6_k consists of 208 int8 values followed by 1 f16 value.
# Byte-Swap f16 sized field
delta = tensor.data[block_offs + 208:block_offs + 210].view(dtype=np.uint16)
delta.byteswap(inplace=True)
byteswap_tensors = {
gguf.GGMLQuantizationType.Q4_0: {
"block_size": 18, # 18 bytes = <f16 delta scaling factor> + 16 * <int8 quant>
"byteswap_func": byteswap_q4_0,
},
gguf.GGMLQuantizationType.Q8_0: {
"block_size": 34, # 34 bytes = <f16 delta scaling factor> + 32 * <int8 quant>
"byteswap_func": byteswap_q8_0,
},
gguf.GGMLQuantizationType.Q4_K: {
"block_size": 144, # 144 bytes = 2 * <f16 delta scaling factor> + 140 * <int8 quant>
"byteswap_func": byteswap_q4_k,
},
gguf.GGMLQuantizationType.Q6_K: {
"block_size": 210, # 210 bytes = <f16 delta scaling factor> + 208 * <int8 quant>
"byteswap_func": byteswap_q6_k,
},
}
def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None:
file_endian = reader.endianess.name
if reader.byte_order == 'S':
@ -32,13 +87,11 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None
sys.exit(0)
logger.info("* Checking tensors for conversion compatibility")
for tensor in reader.tensors:
if tensor.tensor_type not in (
gguf.GGMLQuantizationType.F32,
gguf.GGMLQuantizationType.F16,
gguf.GGMLQuantizationType.Q8_0,
gguf.GGMLQuantizationType.Q4_K,
gguf.GGMLQuantizationType.Q6_K,
):
if tensor.tensor_type not in byteswap_tensors and \
tensor.tensor_type not in (
gguf.GGMLQuantizationType.F32,
gguf.GGMLQuantizationType.F16,
):
raise ValueError(f"Cannot handle type {tensor.tensor_type.name} for tensor {repr(tensor.name)}")
logger.info(f"* Preparing to convert from {file_endian} to {order}")
if args.dry_run:
@ -72,78 +125,29 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None
part.byteswap(inplace=True)
# Byte-swap tensor data if necessary
if tensor.tensor_type == gguf.GGMLQuantizationType.Q8_0:
# Handle Q8_0 tensor blocks (block_q8_0)
# Specific handling of block_q8_0 is required.
# Each block_q8_0 consists of an f16 delta (scaling factor) followed by 32 int8 quantizations.
block_size = 34 # 34 bytes = <f16 delta scaling factor> + 32 * <int8 quant>
n_blocks = len(tensor.data) // block_size
for block_num in (inner_pbar := tqdm(range(n_blocks), desc="Byte-swapping Blocks", leave=False)):
block_offs = block_num * block_size
# Byte-Swap f16 sized delta field
delta = tensor.data[block_offs:block_offs + 2].view(dtype=np.uint16)
delta.byteswap(inplace=True)
# Byte-Swap Q8 weights
if block_num % 100000 == 0:
inner_pbar.set_description(f"Byte-swapping Blocks [{(n_blocks - block_num) // n_blocks}]")
elif tensor.tensor_type == gguf.GGMLQuantizationType.Q4_K:
# Handle Q4_K tensor blocks (block_q4_k)
# Specific handling of block_q4_k is required.
# Each block_q4_k consists of 2 f16 values followed by 140 int8 values.
if tensor.tensor_type in byteswap_tensors:
# first flatten structure
oldshape = tensor.data.shape
newshape = 1
for i in tensor.data.shape:
newshape *= i
tensor.data.resize(newshape)
block_size = 144
block_size = byteswap_tensors[tensor.tensor_type]["block_size"]
byteswap_func = byteswap_tensors[tensor.tensor_type]["byteswap_func"]
n_blocks = len(tensor.data) // block_size
for block_num in (inner_pbar := tqdm(range(n_blocks), desc="Byte-swapping Blocks", leave=False)):
block_offs = block_num * block_size
# Byte-Swap f16 sized fields
delta = tensor.data[block_offs:block_offs + 2].view(dtype=np.uint16)
delta.byteswap(inplace=True)
byteswap_func(tensor, block_offs)
delta = tensor.data[block_offs + 2:block_offs + 4].view(dtype=np.uint16)
delta.byteswap(inplace=True)
# Byte-Swap
if block_num % 100000 == 0:
inner_pbar.set_description(f"Byte-swapping Blocks [{(n_blocks - block_num) // n_blocks}]")
elif tensor.tensor_type == gguf.GGMLQuantizationType.Q6_K:
# Handle Q6_K tensor blocks (block_q6_k)
# Specific handling of block_q6_k is required.
# Each block_q6_k consists of 208 int8 values followed by 1 f16 value.
# first flatten structure
newshape = 1
for i in tensor.data.shape:
newshape *= i
tensor.data.resize(newshape)
block_size = 210
n_blocks = len(tensor.data) // block_size
for block_num in (inner_pbar := tqdm(range(n_blocks), desc="Byte-swapping Blocks", leave=False)):
block_offs = block_num * block_size
# Byte-Swap f16 sized field
delta = tensor.data[block_offs + 208:block_offs + 210].view(dtype=np.uint16)
delta.byteswap(inplace=True)
# Byte-Swap
if block_num % 100000 == 0:
inner_pbar.set_description(f"Byte-swapping Blocks [{(n_blocks - block_num) // n_blocks}]")
# restore old shape in case it's ever used
tensor.data.resize(oldshape)
else:
# Handle other tensor types
tensor.data.byteswap(inplace=True)

View File

@ -191,6 +191,7 @@ class TensorNameMap:
"model.layers.{bid}.self_attn.q_proj", # llama4
"model.transformer.blocks.{bid}.q_proj", # llada
"layers.{bid}.self_attn.q_proj", # qwen3-embedding
"backbone.layers.{bid}.mixer.q_proj", # nemotron-h
),
# Attention key
@ -209,6 +210,7 @@ class TensorNameMap:
"model.layers.{bid}.self_attn.k_proj", # llama4
"model.transformer.blocks.{bid}.k_proj", # llada
"layers.{bid}.self_attn.k_proj", # qwen3-embedding
"backbone.layers.{bid}.mixer.k_proj", # nemotron-h
),
# Attention value
@ -226,6 +228,7 @@ class TensorNameMap:
"model.layers.{bid}.self_attn.v_proj", # llama4
"model.transformer.blocks.{bid}.v_proj", # llada
"layers.{bid}.self_attn.v_proj", # qwen3-embedding
"backbone.layers.{bid}.mixer.v_proj", # nemotron-h
),
# Attention output
@ -260,6 +263,7 @@ class TensorNameMap:
"transformer_encoder.{bid}.wo", # neobert
"model.transformer.blocks.{bid}.attn_out", # llada
"layers.{bid}.self_attn.o_proj", # qwen3-embedding
"backbone.layers.{bid}.mixer.o_proj", # nemotron-h
),
# Attention output norm
@ -387,6 +391,7 @@ class TensorNameMap:
"model.layers.{bid}.block_sparse_moe.up", # smallthinker
"model.transformer.blocks.{bid}.up_proj", # llada
"layers.{bid}.mlp.up_proj", # qwen3-embedding
"backbone.layers.{bid}.mixer.up_proj", # nemotron-h
),
MODEL_TENSOR.FFN_UP_EXP: (
@ -427,7 +432,6 @@ class TensorNameMap:
"model.layers.{bid}.residual_mlp.w1", # arctic
"transformer.h.{bid}.mlp.c_fc_0", # exaone
"model.layers.{bid}.feed_forward.gate_proj", # llama4 jamba granite-hybrid
"model.layers.{bid}.block_sparse_moe.gate", # smallthinker
"model.transformer.blocks.{bid}.ff_proj", # llada
"layers.{bid}.mlp.gate_proj", # qwen3-embedding
),
@ -481,6 +485,7 @@ class TensorNameMap:
"model.layers.{bid}.block_sparse_moe.down", # smallthinker
"model.transformer.blocks.{bid}.ff_out", # llada
"layers.{bid}.mlp.down_proj", # qwen3-embedding
"backbone.layers.{bid}.mixer.down_proj", # nemotron-h
),
MODEL_TENSOR.FFN_DOWN_EXP: (
@ -1123,6 +1128,7 @@ class TensorNameMap:
"vision_encoder.patch_conv", # pixtral
"vision_model.patch_embedding.linear", # llama 4
"visual.patch_embed.proj", # qwen2vl
"vision_tower.patch_embed.proj", # kimi-vl
),
MODEL_TENSOR.V_ENC_EMBD_POS: (
@ -1131,6 +1137,7 @@ class TensorNameMap:
"vpm.embeddings.position_embedding",
"model.vision_model.embeddings.position_embedding", # SmolVLM
"vision_model.positional_embedding_vlm", # llama 4
"vision_tower.patch_embed.pos_emb", # kimi-vl
),
MODEL_TENSOR.V_ENC_ATTN_Q: (
@ -1142,6 +1149,7 @@ class TensorNameMap:
"vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral-hf
"vision_encoder.transformer.layers.{bid}.attention.wq", # pixtral
"visual.blocks.{bid}.attn.q", # qwen2vl, generated
"vision_tower.encoder.blocks.{bid}.wq", # kimi-vl, generated
),
MODEL_TENSOR.V_ENC_ATTN_Q_NORM: (
@ -1158,6 +1166,7 @@ class TensorNameMap:
"vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral-hf
"vision_encoder.transformer.layers.{bid}.attention.wk", # pixtral
"visual.blocks.{bid}.attn.k", # qwen2vl, generated
"vision_tower.encoder.blocks.{bid}.wk", # kimi-vl, generated
),
MODEL_TENSOR.V_ENC_ATTN_K_NORM: (
@ -1174,6 +1183,7 @@ class TensorNameMap:
"vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral-hf
"vision_encoder.transformer.layers.{bid}.attention.wv", # pixtral
"visual.blocks.{bid}.attn.v", # qwen2vl, generated
"vision_tower.encoder.blocks.{bid}.wv", # kimi-vl, generated
),
MODEL_TENSOR.V_ENC_INPUT_NORM: (
@ -1186,6 +1196,7 @@ class TensorNameMap:
"vision_encoder.transformer.layers.{bid}.attention_norm", # pixtral
"vision_model.model.layers.{bid}.input_layernorm", # llama4
"visual.blocks.{bid}.norm1", # qwen2vl
"vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1)
),
MODEL_TENSOR.V_ENC_ATTN_O: (
@ -1198,6 +1209,7 @@ class TensorNameMap:
"vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral-hf
"vision_encoder.transformer.layers.{bid}.attention.wo", # pixtral
"visual.blocks.{bid}.attn.proj", # qwen2vl
"vision_tower.encoder.blocks.{bid}.wo", # kimi-vl
),
MODEL_TENSOR.V_ENC_POST_ATTN_NORM: (
@ -1210,6 +1222,7 @@ class TensorNameMap:
"vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral-hf
"vision_encoder.transformer.layers.{bid}.ffn_norm", # pixtral
"visual.blocks.{bid}.norm2", # qwen2vl
"vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1)
),
MODEL_TENSOR.V_ENC_FFN_UP: (
@ -1222,6 +1235,7 @@ class TensorNameMap:
"vision_model.model.layers.{bid}.mlp.fc1", # llama4
"visual.blocks.{bid}.mlp.fc1", # qwen2vl
"visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl
"vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1)
),
MODEL_TENSOR.V_ENC_FFN_GATE: (
@ -1240,6 +1254,7 @@ class TensorNameMap:
"vision_model.model.layers.{bid}.mlp.fc2", # llama4
"visual.blocks.{bid}.mlp.fc2", # qwen2vl
"visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
"vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1)
),
MODEL_TENSOR.V_LAYER_SCALE_1: (
@ -1264,6 +1279,7 @@ class TensorNameMap:
"model.vision_model.post_layernorm", # SmolVLM
"vision_model.layernorm_post", # llama4
"visual.merger.ln_q", # qwen2vl
"vision_tower.encoder.final_layernorm", # kimi-vl
),
MODEL_TENSOR.V_MM_INP_PROJ: (
@ -1273,6 +1289,7 @@ class TensorNameMap:
MODEL_TENSOR.V_MM_INP_NORM: (
"multi_modal_projector.norm",
"multi_modal_projector.layer_norm",
"multi_modal_projector.pre_norm",
"pre_mm_projector_norm",
),

View File

@ -556,6 +556,24 @@ extern "C" {
struct llama_model * model,
const char * path_lora);
// Functions to access the adapter's GGUF metadata scalar values
// - The functions return the length of the string on success, or -1 on failure
// - The output string is always null-terminated and cleared on failure
// - When retrieving a string, an extra byte must be allocated to account for the null terminator
// - GGUF array values are not supported by these functions
// Get metadata value as a string by key name
LLAMA_API int32_t llama_adapter_meta_val_str(const struct llama_adapter_lora * adapter, const char * key, char * buf, size_t buf_size);
// Get the number of metadata key/value pairs
LLAMA_API int32_t llama_adapter_meta_count(const struct llama_adapter_lora * adapter);
// Get metadata key name by index
LLAMA_API int32_t llama_adapter_meta_key_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size);
// Get metadata value as a string by index
LLAMA_API int32_t llama_adapter_meta_val_str_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size);
// Manually free a LoRA adapter
// Note: loaded adapters will be free when the associated model is deleted
LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);

View File

@ -0,0 +1,171 @@
{# ---------- special token variables ---------- #}
{%- set bos_token = '<seed:bos>' -%}
{%- set eos_token = '<seed:eos>' -%}
{%- set pad_token = '<seed:pad>' -%}
{%- set toolcall_begin_token = '<seed:tool_call>' -%}
{%- set toolcall_end_token = '</seed:tool_call>' -%}
{%- set think_begin_token = '<seed:think>' -%}
{%- set think_end_token = '</seed:think>' -%}
{%- set budget_begin_token = '<seed:cot_budget_reflect>'-%}
{%- set budget_end_token = '</seed:cot_budget_reflect>'-%}
{# -------------- reflection-interval lookup -------------- #}
{%- if not thinking_budget is defined %}
{%- set thinking_budget = -1 -%}
{%- endif -%}
{%- set budget_reflections_v05 = {
0: 0,
512: 128,
1024: 256,
2048: 512,
4096: 512,
8192: 1024,
16384: 1024
} -%}
{# Find the first gear that is greater than or equal to the thinking_budget. #}
{%- set ns = namespace(interval = None) -%}
{%- for k, v in budget_reflections_v05 | dictsort -%}
{%- if ns.interval is none and thinking_budget <= k -%}
{%- set ns.interval = v -%}
{%- endif -%}
{%- endfor -%}
{# If it exceeds the maximum gear, use the value of the last gear #}
{%- if ns.interval is none -%}
{%- set ns.interval = budget_reflections_v05[16384] -%}
{%- endif -%}
{# ---------- Preprocess the system message ---------- #}
{%- if messages[0]["role"] == "system" %}
{%- set system_message = messages[0]["content"] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set loop_messages = messages %}
{%- endif %}
{# ---------- Ensure tools exist ---------- #}
{%- if not tools is defined or tools is none %}
{%- set tools = [] %}
{%- endif %}
{# tools2doc.jinja #}
{%- macro py_type(t) -%}
{%- if t == "string" -%}str
{%- elif t in ("number", "integer") -%}int
{%- elif t == "boolean" -%}bool
{%- elif t == "array" -%}list
{%- else -%}Any{%- endif -%}
{%- endmacro -%}
{# ---------- Output the system block ---------- #}
{%- if system_message is defined %}
{{ bos_token + "system\n" + system_message }}
{%- else %}
{%- if tools is iterable and tools | length > 0 %}
{{ bos_token + "system\nYou are Doubao, a helpful AI assistant. You may call one or more functions to assist with the user query." }}
{%- endif %}
{%- endif %}
{%- if use_json_tooldef is defined and use_json_tooldef %}
{{"Tool List:\nYou are authorized to use the following tools (described in JSON Schema format). Before performing any task, you must decide how to call them based on the descriptions and parameters of these tools."}}
{{ tools | tojson(ensure_ascii=False) }}
{%- else %}
{%- for item in tools if item.type == "function" %}
Function:
def {{ item.function.name }}(
{%- for name, spec in item.function.parameters.properties.items() %}
{{- name }}: {{ py_type(spec.type) }}{% if not loop.last %},{% endif %}
{%- endfor %}):
"""
{{ item.function.description | trim }}
{# ---------- Args ---------- #}
{%- if item.function.parameters.properties %}
Args:
{%- for name, spec in item.function.parameters.properties.items() %}
- {{ name }} ({{ py_type(spec.type) }})
{%- if name in item.function.parameters.required %} [必填]{% else %} [选填]{% endif %}:
{{- " " ~ (spec.description or "") }}
{%- endfor %}
{%- endif %}
{# ---------- Returns ---------- #}
{%- if item.function.returns is defined
and item.function.returns.properties is defined
and item.function.returns.properties %}
Returns:
{%- for name, spec in item.function.returns.properties.items() %}
- {{ name }} ({{ py_type(spec.type) }}):
{{- " " ~ (spec.description or "") }}
{%- endfor %}
{%- endif %}
"""
{%- endfor %}
{%- endif %}
{%- if tools is iterable and tools | length > 0 %}
{{"工具调用请遵循如下格式:\n<seed:tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>value_1</parameter>\n<parameter=example_parameter_2>This is the value for the second parameter\nthat can span\nmultiple lines</parameter>\n</function>\n</seed:tool_call>\n"}}
{%- endif %}
{# End the system block line #}
{%- if system_message is defined or tools is iterable and tools | length > 0 %}
{{ eos_token }}
{%- endif %}
{# ---------- Thinking Budget ---------- #}
{%- if thinking_budget is defined %}
{%- if thinking_budget == 0 %}
{{ bos_token+"system" }}
{{ "You are an intelligent assistant that can answer questions in one step without the need for reasoning and thinking, that is, your thinking budget is 0. Next, please skip the thinking process and directly start answering the user's questions." }}
{{ eos_token }}
{%- elif not thinking_budget == -1 %}
{{ bos_token+"system" }}
{{ "You are an intelligent assistant with reflective ability. In the process of thinking and reasoning, you need to strictly follow the thinking budget, which is "}}{{thinking_budget}}{{". That is, you need to complete your thinking within "}}{{thinking_budget}}{{" tokens and start answering the user's questions. You will reflect on your thinking process every "}}{{ns.interval}}{{" tokens, stating how many tokens have been used and how many are left."}}
{{ eos_token }}
{%- endif %}
{%- endif %}
{# ---------- List the historical messages one by one ---------- #}
{%- for message in loop_messages %}
{%- if message.role == "assistant"
and message.tool_calls is defined
and message.tool_calls is iterable
and message.tool_calls | length > 0 %}
{{ bos_token + message.role }}
{%- if message.reasoning_content is defined and message.reasoning_content is string and message.reasoning_content | trim | length > 0 %}
{{ "\n" + think_begin_token + message.reasoning_content | trim + think_end_token }}
{%- endif %}
{%- if message.content is defined and message.content is string and message.content | trim | length > 0 %}
{{ "\n" + message.content | trim + "\n" }}
{%- endif %}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function is defined %}{% set tool_call = tool_call.function %}{% endif %}
{{ "\n" + toolcall_begin_token + "\n<function=" + tool_call.name + ">\n" }}
{%- if tool_call.arguments is defined %}
{%- for arg_name, arg_value in tool_call.arguments | items %}
{{ "<parameter=" + arg_name + ">" }}
{%- set arg_value = arg_value if arg_value is string else arg_value | string %}
{{ arg_value+"</parameter>\n" }}
{%- endfor %}
{%- endif %}
{{ "</function>\n" + toolcall_end_token }}
{%- endfor %}
{{ eos_token }}
{%- elif message.role in ["user", "system"] %}
{{ bos_token + message.role + "\n" + message.content + eos_token }}
{%- elif message.role == "assistant" %}
{{ bos_token + message.role }}
{%- if message.reasoning_content is defined and message.reasoning_content is string and message.reasoning_content | trim | length > 0 %}
{{ "\n" + think_begin_token + message.reasoning_content | trim + think_end_token }}
{%- endif %}
{%- if message.content is defined and message.content is string and message.content | trim | length > 0 %}
{{ "\n" + message.content | trim + eos_token }}
{%- endif %}
{# Include the tool role #}
{%- else %}
{{ bos_token + message.role + "\n" + message.content + eos_token }}
{%- endif %}
{%- endfor %}
{# ---------- Control the model to start continuation ---------- #}
{%- if add_generation_prompt %}
{{ bos_token+"assistant\n" }}
{%- if thinking_budget == 0 %}
{{ think_begin_token + "\n" + budget_begin_token + "The current thinking budget is 0, so I will directly start answering the question." + budget_end_token + "\n" + think_end_token }}
{%- endif %}
{%- endif %}

View File

@ -25,6 +25,12 @@ fi
# verify at the start that the compare script has all the necessary dependencies installed
./scripts/compare-llama-bench.py --check
if ! command -v sqlite3 >/dev/null 2>&1; then
echo "Error: sqlite3 is not installed or not in PATH"
echo "Please install sqlite3 to use this script"
exit 1
fi
if [ "$tool" = "llama-bench" ]; then
db_file="llama-bench.sqlite"
target="llama-bench"

View File

@ -96,7 +96,7 @@ DEFAULT_HIDE_LLAMA_BENCH = ["model_filename"] # Always hide these properties by
DEFAULT_SHOW_TEST_BACKEND_OPS = ["backend_name", "op_name"] # Always show these properties by default.
DEFAULT_HIDE_TEST_BACKEND_OPS = ["error_message"] # Always hide these properties by default.
GPU_NAME_STRIP = ["NVIDIA GeForce ", "Tesla ", "AMD Radeon "] # Strip prefixes for smaller tables.
GPU_NAME_STRIP = ["NVIDIA GeForce ", "Tesla ", "AMD Radeon ", "AMD Instinct "] # Strip prefixes for smaller tables.
MODEL_SUFFIX_REPLACE = {" - Small": "_S", " - Medium": "_M", " - Large": "_L"}
DESCRIPTION = """Creates tables from llama-bench or test-backend-ops data written to multiple JSON/CSV files, a single JSONL file or SQLite database. Example usage (Linux):

View File

@ -163,13 +163,38 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
// check metadata
{
const gguf_context * gguf_ctx = ctx_gguf.get();
LLAMA_LOG_INFO("%s: Dumping metadata keys/values.\n", __func__);
// get metadata as string
for (int i = 0; i < gguf_get_n_kv(gguf_ctx); i++) {
gguf_type type = gguf_get_kv_type(gguf_ctx, i);
const std::string type_name =
type == GGUF_TYPE_ARRAY
? format("%s[%s,%zu]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(gguf_ctx, i)), gguf_get_arr_n(gguf_ctx, i))
: gguf_type_name(type);
const char * name = gguf_get_key(gguf_ctx, i);
const std::string value = gguf_kv_to_str(gguf_ctx, i);
if (type != GGUF_TYPE_ARRAY) {
adapter.gguf_kv.emplace(name, value);
}
const size_t MAX_VALUE_LEN = 40;
std::string print_value = value.size() > MAX_VALUE_LEN ? format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str()) : value;
replace_all(print_value, "\n", "\\n");
LLAMA_LOG_INFO("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), print_value.c_str());
}
auto get_kv_str = [&](const std::string & key) -> std::string {
int id = gguf_find_key(ctx_gguf.get(), key.c_str());
return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf.get(), id));
int id = gguf_find_key(gguf_ctx, key.c_str());
return id < 0 ? "" : std::string(gguf_get_val_str(gguf_ctx, id));
};
auto get_kv_f32 = [&](const std::string & key) -> float {
int id = gguf_find_key(ctx_gguf.get(), key.c_str());
return id < 0 ? 0.0f : gguf_get_val_f32(ctx_gguf.get(), id);
int id = gguf_find_key(gguf_ctx, key.c_str());
return id < 0 ? 0.0f : gguf_get_val_f32(gguf_ctx, id);
};
LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
@ -383,6 +408,45 @@ llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * p
return nullptr;
}
int32_t llama_adapter_meta_val_str(const llama_adapter_lora * adapter, const char * key, char * buf, size_t buf_size) {
const auto & it = adapter->gguf_kv.find(key);
if (it == adapter->gguf_kv.end()) {
if (buf_size > 0) {
buf[0] = '\0';
}
return -1;
}
return snprintf(buf, buf_size, "%s", it->second.c_str());
}
int32_t llama_adapter_meta_count(const llama_adapter_lora * adapter) {
return (int)adapter->gguf_kv.size();
}
int32_t llama_adapter_meta_key_by_index(const llama_adapter_lora * adapter, int i, char * buf, size_t buf_size) {
if (i < 0 || i >= (int)adapter->gguf_kv.size()) {
if (buf_size > 0) {
buf[0] = '\0';
}
return -1;
}
auto it = adapter->gguf_kv.begin();
std::advance(it, i);
return snprintf(buf, buf_size, "%s", it->first.c_str());
}
int32_t llama_adapter_meta_val_str_by_index(const llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size) {
if (i < 0 || i >= (int)adapter->gguf_kv.size()) {
if (buf_size > 0) {
buf[0] = '\0';
}
return -1;
}
auto it = adapter->gguf_kv.begin();
std::advance(it, i);
return snprintf(buf, buf_size, "%s", it->second.c_str());
}
void llama_adapter_lora_free(llama_adapter_lora * adapter) {
delete adapter;
}

View File

@ -67,6 +67,9 @@ struct llama_adapter_lora {
float alpha;
// gguf metadata
std::unordered_map<std::string, std::string> gguf_kv;
llama_adapter_lora() = default;
~llama_adapter_lora() = default;

View File

@ -22,6 +22,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
{ LLM_ARCH_NEO_BERT, "neo-bert" },
{ LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
{ LLM_ARCH_JINA_BERT_V3, "jina-bert-v3" },
{ LLM_ARCH_BLOOM, "bloom" },
{ LLM_ARCH_STABLELM, "stablelm" },
{ LLM_ARCH_QWEN, "qwen" },
@ -68,6 +69,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_T5ENCODER, "t5encoder" },
{ LLM_ARCH_JAIS, "jais" },
{ LLM_ARCH_NEMOTRON, "nemotron" },
{ LLM_ARCH_NEMOTRON_H, "nemotron_h" },
{ LLM_ARCH_EXAONE, "exaone" },
{ LLM_ARCH_EXAONE4, "exaone4" },
{ LLM_ARCH_RWKV6, "rwkv6" },
@ -234,8 +236,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" },
{ LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" },
{ LLM_KV_ADAPTER_TYPE, "adapter.type" },
{ LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" },
{ LLM_KV_ADAPTER_TYPE, "adapter.type" },
{ LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" },
{ LLM_KV_ADAPTER_LORA_TASK_NAME, "adapter.lora.task_name" },
{ LLM_KV_ADAPTER_LORA_PROMPT_PREFIX, "adapter.lora.prompt_prefix" },
// deprecated
{ LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" },
@ -575,6 +579,20 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_CLS, "cls" },
},
},
{
LLM_ARCH_JINA_BERT_V3,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_TOKEN_TYPES, "token_types" },
{ LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
},
},
{
LLM_ARCH_BLOOM,
{
@ -1533,6 +1551,31 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_NEMOTRON_H,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
// mamba(2) ssm layers
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
// attention layers
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
// dense FFN
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_EXAONE,
{
@ -2338,6 +2381,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
case LLM_ARCH_PLAMO2:
case LLM_ARCH_GRANITE_HYBRID:
case LLM_ARCH_LFM2:
case LLM_ARCH_NEMOTRON_H:
return true;
default:
return false;

View File

@ -26,6 +26,7 @@ enum llm_arch {
LLM_ARCH_NOMIC_BERT_MOE,
LLM_ARCH_NEO_BERT,
LLM_ARCH_JINA_BERT_V2,
LLM_ARCH_JINA_BERT_V3,
LLM_ARCH_BLOOM,
LLM_ARCH_STABLELM,
LLM_ARCH_QWEN,
@ -72,6 +73,7 @@ enum llm_arch {
LLM_ARCH_T5ENCODER,
LLM_ARCH_JAIS,
LLM_ARCH_NEMOTRON,
LLM_ARCH_NEMOTRON_H,
LLM_ARCH_EXAONE,
LLM_ARCH_EXAONE4,
LLM_ARCH_RWKV6,
@ -230,6 +232,8 @@ enum llm_kv {
LLM_KV_ADAPTER_TYPE,
LLM_KV_ADAPTER_LORA_ALPHA,
LLM_KV_ADAPTER_LORA_TASK_NAME,
LLM_KV_ADAPTER_LORA_PROMPT_PREFIX,
LLM_KV_POSNET_EMBEDDING_LENGTH,
LLM_KV_POSNET_BLOCK_COUNT,

View File

@ -102,16 +102,6 @@ llama_context::llama_context(
cparams.op_offload = params.op_offload;
cparams.kv_unified = params.kv_unified;
{
const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : supports_set_rows;
if (!supports_set_rows && !cparams.kv_unified) {
LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__);
cparams.kv_unified = true;
}
}
{
const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE");
graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable;
@ -280,7 +270,7 @@ llama_context::llama_context(
}
// reserve worst-case graph
if (!hparams.vocab_only && memory) {
if (!hparams.vocab_only) {
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
@ -292,11 +282,13 @@ llama_context::llama_context(
int n_splits_tg = -1;
int n_nodes_tg = -1;
// simulate full KV cache
const auto mctx = memory->init_full();
if (!mctx) {
throw std::runtime_error("failed to initialize KV cache");
llama_memory_context_ptr mctx;
if (memory) {
LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
mctx = memory->init_full();
if (!mctx) {
throw std::runtime_error("failed to initialize memory module");
}
}
cross.v_embd.clear();
@ -888,12 +880,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
}
}
if (!supports_set_rows) {
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
// overlap with device computation.
ggml_backend_sched_reset(sched.get());
}
// TODO: hacky solution
if (model.arch == LLM_ARCH_T5 && t_embd) {
//cross.t_embd = t_embd;
@ -1056,7 +1042,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
if (!res) {
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module
llama_pos pos_min[LLAMA_MAX_SEQ];
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
pos_min[s] = std::numeric_limits<llama_pos>::max();
@ -1073,7 +1059,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
continue;
}
LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
LLAMA_LOG_WARN("%s: removing memory module entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
memory->seq_rm(s, pos_min[s], -1);
}
@ -1224,12 +1210,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
// wait for the computation to finish (automatically done when obtaining the model output)
//synchronize();
if (!supports_set_rows) {
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
// overlap with device computation.
ggml_backend_sched_reset(sched.get());
}
return 0;
}
@ -1857,7 +1837,7 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
}
if (memory != nullptr) {
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
memory->state_write(io);
}
@ -1943,7 +1923,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
}
if (memory) {
LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);
memory->state_read(io);
}

View File

@ -283,10 +283,6 @@ private:
bool has_evaluated_once = false;
// env: LLAMA_SET_ROWS (temporary)
// ref: https://github.com/ggml-org/llama.cpp/pull/14285
bool supports_set_rows = true;
// env: LLAMA_GRAPH_REUSE_DISABLE
bool graph_reuse_disable = false;

View File

@ -314,8 +314,6 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
res &= self_kq_mask->ne[0] == mctx->get_n_kv();
res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
res &= mctx->get_supports_set_rows(); // TODO: tmp
return res;
}
@ -350,8 +348,6 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
res &= mctx->get_base()->get_supports_set_rows(); // TODO: tmp
return res;
}
@ -1376,7 +1372,7 @@ ggml_tensor * llm_graph_context::build_attn(
// [TAG_NO_CACHE_PAD]
// TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
assert(!ubatch.equal_seqs());
assert(!ubatch.equal_seqs() || (k_cur->ne[3] == 1 && k_cur->ne[3] == ubatch.n_seqs_unq));
ggml_tensor * q = q_cur;
ggml_tensor * k = k_cur;

View File

@ -197,18 +197,6 @@ llama_kv_cache::llama_kv_cache(
const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) != 0 : supports_set_rows;
if (!supports_set_rows) {
// ref: https://github.com/ggml-org/llama.cpp/pull/14363
GGML_ASSERT(unified && "cannot use non-unified KV cache without ggml_set_rows() support");
}
if (!supports_set_rows) {
LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
}
}
void llama_kv_cache::clear(bool data) {
@ -551,11 +539,8 @@ llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_
bool success = true;
for (const auto & ubatch : ubatches) {
// non-continuous slots require support for ggml_set_rows()
const bool cont = supports_set_rows ? false : true;
// only find a suitable slot for the ubatch. don't modify the cells yet
const auto sinfo_new = find_slot(ubatch, cont);
const auto sinfo_new = find_slot(ubatch, false);
if (sinfo_new.empty()) {
success = false;
break;
@ -771,8 +756,8 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch,
GGML_ASSERT(ubatch.seq_id [s*n_tokens][0] == seq_id);
}
res.s0 = std::min<llama_seq_id>(res.s0, seq_to_stream[seq_id]);
res.s1 = std::max<llama_seq_id>(res.s1, seq_to_stream[seq_id]);
res.s0 = std::min<uint32_t>(res.s0, seq_to_stream[seq_id]);
res.s1 = std::max<uint32_t>(res.s1, seq_to_stream[seq_id]);
res.strm[s] = seq_to_stream[seq_id];
res.idxs[s].reserve(n_tokens);
@ -964,11 +949,11 @@ bool llama_kv_cache::get_has_shift() const {
return result;
}
uint32_t llama_kv_cache::get_n_kv() const {
uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
uint32_t result = 0;
for (uint32_t s = 0; s < n_stream; ++s) {
const auto & cells = v_cells[s];
for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
const auto & cells = v_cells[sinfo.strm[s]];
result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
}
@ -976,10 +961,6 @@ uint32_t llama_kv_cache::get_n_kv() const {
return result;
}
bool llama_kv_cache::get_supports_set_rows() const {
return supports_set_rows;
}
ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
const int32_t ikv = map_layer_ids.at(il);
@ -1017,52 +998,42 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k
// note: v->nb[1] <= v->nb[2]
return ggml_view_4d(ctx, v,
hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2]
ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2]
ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
}
// note: v->nb[1] > v->nb[2]
return ggml_view_4d(ctx, v,
n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
ggml_row_size(v->type, kv_size), // v->nb[2]
ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
ggml_row_size(v->type, kv_size), // v->nb[2]
ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
}
ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
GGML_UNUSED(sinfo);
const int32_t ikv = map_layer_ids.at(il);
auto * k = layers[ikv].k;
const int64_t n_embd_k_gqa = k->ne[0];
const int64_t n_tokens = k_cur->ne[2];
k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
if (k_idxs && supports_set_rows) {
if (k->ne[2] > 1) {
k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
}
return ggml_set_rows(ctx, k, k_cur, k_idxs);
if (k->ne[2] > 1) {
k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
}
// TODO: fallback to old ggml_cpy() method for backwards compatibility
// will be removed when ggml_set_rows() is adopted by all backends
GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");
ggml_tensor * k_view = ggml_view_1d(ctx, k,
n_tokens*n_embd_k_gqa,
ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
return ggml_cpy(ctx, k_cur, k_view);
return ggml_set_rows(ctx, k, k_cur, k_idxs);
}
ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
GGML_UNUSED(sinfo);
const int32_t ikv = map_layer_ids.at(il);
auto * v = layers[ikv].v;
@ -1072,48 +1043,25 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm
v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
if (v_idxs && supports_set_rows) {
if (!v_trans) {
if (v->ne[2] > 1) {
v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
}
return ggml_set_rows(ctx, v, v_cur, v_idxs);
}
// [TAG_V_CACHE_VARIABLE]
if (n_embd_v_gqa < v->ne[0]) {
v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
}
// the row becomes a single element
ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
}
// TODO: fallback to old ggml_cpy() method for backwards compatibility
// will be removed when ggml_set_rows() is adopted by all backends
GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");
ggml_tensor * v_view = nullptr;
if (!v_trans) {
v_view = ggml_view_1d(ctx, v,
n_tokens*n_embd_v_gqa,
ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head());
} else {
v_cur = ggml_transpose(ctx, v_cur);
if (v->ne[2] > 1) {
v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
}
v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa,
(v->ne[1] )*ggml_element_size(v),
(sinfo.head())*ggml_element_size(v));
return ggml_set_rows(ctx, v, v_cur, v_idxs);
}
return ggml_cpy(ctx, v_cur, v_view);
// [TAG_V_CACHE_VARIABLE]
if (n_embd_v_gqa < v->ne[0]) {
v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
}
// the row becomes a single element
ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
}
ggml_tensor * llama_kv_cache::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
@ -1143,10 +1091,6 @@ ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama
}
void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
if (!supports_set_rows) {
return;
}
const uint32_t n_tokens = ubatch->n_tokens;
GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
@ -1163,10 +1107,6 @@ void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ub
}
void llama_kv_cache::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
if (!supports_set_rows) {
return;
}
const uint32_t n_tokens = ubatch->n_tokens;
GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
@ -1985,8 +1925,7 @@ bool llama_kv_cache_context::apply() {
}
kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
n_kv = kv->get_n_kv();
n_kv = kv->get_n_kv(sinfos[i_cur]);
return true;
}
@ -2005,10 +1944,6 @@ uint32_t llama_kv_cache_context::get_n_kv() const {
return n_kv;
}
bool llama_kv_cache_context::get_supports_set_rows() const {
return kv->get_supports_set_rows();
}
ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const {
return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
}

View File

@ -38,8 +38,8 @@ public:
using idx_vec_t = std::vector<uint32_t>;
// number of streams: ns = s1 - s0 + 1
llama_seq_id s0;
llama_seq_id s1;
uint32_t s0;
uint32_t s1;
std::vector<llama_seq_id> strm; // [ns]
std::vector<idx_vec_t> idxs; // [ns]
@ -139,10 +139,7 @@ public:
// graph_build API
//
uint32_t get_n_kv() const;
// TODO: temporary
bool get_supports_set_rows() const;
uint32_t get_n_kv(const slot_info & sinfo) const;
// get views of the current state of the cache
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
@ -215,10 +212,6 @@ private:
// env: LLAMA_KV_CACHE_DEBUG
int debug = 0;
// env: LLAMA_SET_ROWS (temporary)
// ref: https://github.com/ggml-org/llama.cpp/pull/14285
bool supports_set_rows = true;
const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
std::vector<ggml_context_ptr> ctxs;
@ -318,9 +311,6 @@ public:
uint32_t get_n_kv() const;
// TODO: temporary
bool get_supports_set_rows() const;
// get views of the current state of the cache
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;

View File

@ -788,6 +788,7 @@ const struct ggml_tensor * llama_model_loader::check_tensor_dims(const std::stri
}
struct ggml_tensor * llama_model_loader::create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list<int64_t> & ne, int flags) {
LLAMA_LOG_DEBUG("%s: loading tensor %s\n", __func__, name.c_str());
const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED));
if (cur == NULL) {

View File

@ -47,6 +47,7 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_410M: return "410M";
case LLM_TYPE_450M: return "450M";
case LLM_TYPE_475M: return "475M";
case LLM_TYPE_558M: return "558M";
case LLM_TYPE_700M: return "700M";
case LLM_TYPE_770M: return "770M";
case LLM_TYPE_780M: return "780M";
@ -772,6 +773,18 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_JINA_BERT_V3:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
switch (hparams.n_layer) {
case 24:
type = LLM_TYPE_558M; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_NOMIC_BERT:
case LLM_ARCH_NOMIC_BERT_MOE:
{
@ -1557,6 +1570,27 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_NEMOTRON_H:
{
ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
// A layer is recurrent IFF the n_head_kv value is set to 0 and
// the n_ff value is set to 0
for (uint32_t i = 0; i < hparams.n_layer; ++i) {
hparams.recurrent_layer_arr[i] = (hparams.n_head_kv(i) == 0 && hparams.n_ff(i) == 0);
}
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 56: type = LLM_TYPE_9B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_EXAONE:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@ -2631,6 +2665,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
case LLM_ARCH_BERT:
case LLM_ARCH_NOMIC_BERT:
case LLM_ARCH_NOMIC_BERT_MOE:
case LLM_ARCH_JINA_BERT_V3:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED);
@ -2666,24 +2701,22 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0);
if (hparams.moe_every_n_layers > 0 && i % hparams.moe_every_n_layers == 1) {
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff, n_expert}, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0);
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
} else {
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
if (arch == LLM_ARCH_BERT || arch == LLM_ARCH_NOMIC_BERT_MOE) {
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0);
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
} else {
if (arch == LLM_ARCH_NOMIC_BERT) {
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
}
}
@ -4676,6 +4709,75 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED);
}
} break;
case LLM_ARCH_NEMOTRON_H:
{
// mamba2 Mixer SSM params
// NOTE: int64_t for tensor dimensions
const int64_t d_conv = hparams.ssm_d_conv;
const int64_t d_inner = hparams.ssm_d_inner;
const int64_t d_state = hparams.ssm_d_state;
const int64_t n_ssm_head = hparams.ssm_dt_rank;
const int64_t n_group = hparams.ssm_n_group;
const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head;
// embeddings
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// output
{
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed, duplicated to allow offloading
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}
}
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
// all blocks use the attn norm
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
if (hparams.is_recurrent(i)) {
// ssm layers
layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0);
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0);
layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, TENSOR_NOT_REQUIRED);
layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0);
// no "weight" suffix for these
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0);
layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0);
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0);
// out_proj
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
} else if (hparams.n_ff(i) == 0) {
// attention layers (with optional bias)
const int64_t n_head_i = hparams.n_head(i);
const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i);
const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_i}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa_i}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa_i}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0);
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, TENSOR_NOT_REQUIRED);
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, TENSOR_NOT_REQUIRED);
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
} else {
// mlp layers
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { hparams.n_ff(i), n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, hparams.n_ff(i)}, 0);
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED);
}
}
} break;
case LLM_ARCH_EXAONE:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@ -5850,7 +5952,8 @@ void llama_model::print_info() const {
arch == LLM_ARCH_JAMBA ||
arch == LLM_ARCH_FALCON_H1 ||
arch == LLM_ARCH_PLAMO2 ||
arch == LLM_ARCH_GRANITE_HYBRID) {
arch == LLM_ARCH_GRANITE_HYBRID ||
arch == LLM_ARCH_NEMOTRON_H) {
LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
@ -7461,7 +7564,7 @@ struct llm_build_bert : public llm_graph_context {
}
// RoPE
if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE || model.arch == LLM_ARCH_JINA_BERT_V3) {
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@ -7520,7 +7623,7 @@ struct llm_build_bert : public llm_graph_context {
0.0f,
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
cb(cur, "ffn_moe_out", il);
} else if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
} else if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE || model.arch == LLM_ARCH_JINA_BERT_V3) {
cur = build_ffn(cur,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
NULL, NULL, NULL,
@ -14117,6 +14220,138 @@ struct llm_build_nemotron : public llm_graph_context {
}
};
struct llm_build_nemotron_h : public llm_graph_context_mamba {
llm_build_nemotron_h(
const llama_model & model,
const llm_graph_params & params) :
llm_graph_context_mamba(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
ggml_tensor * cur;
ggml_tensor * inpL;
inpL = build_inp_embd(model.tok_embd);
auto * inp = build_inp_mem_hybrid();
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;
// norm
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
if (hparams.is_recurrent(il)) {
// ssm layer //
cur = build_mamba2_layer(inp->get_recr(), cur, model, ubatch, il);
} else if (hparams.n_ff(il) == 0) {
// attention layer //
cur = build_attention_layer(cur, inp->get_attn(), model, n_embd_head, il);
} else {
cur = build_ffn_layer(cur, model, il);
}
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
// add residual
cur = ggml_add(ctx0, cur, inpSA);
cb(cur, "block_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
ggml_tensor * build_attention_layer(
ggml_tensor * cur,
llm_graph_input_attn_kv * inp_attn,
const llama_model & model,
const int64_t n_embd_head,
const int il) {
// compute Q and K and (optionally) RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
}
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
}
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
cur = build_attn(inp_attn,
model.layers[il].wo, model.layers[il].bo,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
cb(cur, "attn_out", il);
return cur;
}
ggml_tensor * build_ffn_layer(
ggml_tensor * cur,
const llama_model & model,
const int il) {
cur = build_ffn(cur,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
NULL, NULL, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
NULL,
LLM_FFN_RELU_SQR, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
return cur;
}
};
struct llm_build_exaone : public llm_graph_context {
llm_build_exaone(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
@ -18241,6 +18476,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
// switch statement
case LLM_ARCH_BERT:
case LLM_ARCH_JINA_BERT_V2:
case LLM_ARCH_JINA_BERT_V3:
case LLM_ARCH_NOMIC_BERT:
case LLM_ARCH_NOMIC_BERT_MOE:
case LLM_ARCH_NEO_BERT:
@ -18264,6 +18500,23 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
cparams.n_seq_max,
nullptr);
} else if (llm_arch_is_hybrid(arch)) {
// The main difference between hybrid architectures is the
// layer filters, so pick the right one here
llama_memory_hybrid::layer_filter_cb filter_attn = nullptr;
llama_memory_hybrid::layer_filter_cb filter_recr = nullptr;
if (arch == LLM_ARCH_FALCON_H1) {
filter_attn = [&](int32_t) { return true; };
filter_recr = [&](int32_t) { return true; };
} else if (arch == LLM_ARCH_NEMOTRON_H) {
filter_attn = [&](int32_t il) {
return !hparams.is_recurrent(il) && hparams.n_ff(il) == 0;
};
filter_recr = [&](int32_t il) {
return hparams.is_recurrent(il) && hparams.n_ff(il) == 0;
};
}
const auto padding = llama_kv_cache::get_padding(cparams);
cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
@ -18283,8 +18536,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* n_seq_max */ cparams.n_seq_max,
/* offload */ cparams.offload_kqv,
/* unified */ cparams.kv_unified,
/* filter_attn */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr,
/* filter_recr */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr);
/* filter_attn */ std::move(filter_attn),
/* filter_recr */ std::move(filter_recr));
} else {
const auto padding = llama_kv_cache::get_padding(cparams);
@ -18395,6 +18648,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
} break;
case LLM_ARCH_BERT:
case LLM_ARCH_JINA_BERT_V2:
case LLM_ARCH_JINA_BERT_V3:
case LLM_ARCH_NOMIC_BERT:
case LLM_ARCH_NOMIC_BERT_MOE:
{
@ -18611,6 +18865,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique<llm_build_nemotron>(*this, params);
} break;
case LLM_ARCH_NEMOTRON_H:
{
llm = std::make_unique<llm_build_nemotron_h>(*this, params);
} break;
case LLM_ARCH_EXAONE:
{
llm = std::make_unique<llm_build_exaone>(*this, params);
@ -18846,6 +19104,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_RWKV7:
case LLM_ARCH_ARWKV7:
case LLM_ARCH_WAVTOKENIZER_DEC:
case LLM_ARCH_NEMOTRON_H:
return LLAMA_ROPE_TYPE_NONE;
// use what we call a normal RoPE, operating on pairs of consecutive head values
@ -18885,6 +19144,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_GROK:
case LLM_ARCH_DBRX:
case LLM_ARCH_BERT:
case LLM_ARCH_JINA_BERT_V3:
case LLM_ARCH_NOMIC_BERT:
case LLM_ARCH_NOMIC_BERT_MOE:
case LLM_ARCH_STABLELM:

View File

@ -40,6 +40,7 @@ enum llm_type {
LLM_TYPE_450M,
LLM_TYPE_475M,
LLM_TYPE_537M,
LLM_TYPE_558M,
LLM_TYPE_700M,
LLM_TYPE_770M,
LLM_TYPE_780M,

View File

@ -2470,7 +2470,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
// set attributes by model/tokenizer/architecture name
if (false
|| _contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})
|| _contains_any(general_arch, {"nomic-bert-moe"})
|| _contains_any(general_arch, {"nomic-bert-moe", "jina-bert-v3"})
) {
if (token_to_id.count("<mask>") == 0) {
LLAMA_LOG_WARN("%s: Mask token is missing in vocab, please reconvert model!\n", __func__);

View File

@ -2209,6 +2209,26 @@ struct test_count_equal : public test_case {
double max_nmse_err() override {
return 0.0;
}
void initialize_tensors(ggml_context * ctx) override {
std::random_device rd;
std::default_random_engine rng(rd());
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->type == GGML_TYPE_F32) {
// initialize with unique values to avoid ties
for (int64_t r = 0; r < ggml_nrows(t); r++) {
std::vector<float> data(t->ne[0]);
for (int i = 0; i < t->ne[0]; i++) {
data[i] = i;
}
std::shuffle(data.begin(), data.end(), rng);
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
}
} else {
init_tensor_uniform(t);
}
}
}
};
// GGML_OP_REPEAT
@ -2769,6 +2789,49 @@ struct test_norm : public test_case {
}
};
// GGML_OP_NORM + GGML_OP_MUL + GGML_OP_ADD
struct test_norm_mul_add : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
float eps;
const bool broadcast;
std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
return "NORM_MUL_ADD";
}
bool run_whole_graph() override { return true; }
std::string vars() override {
return VARS_TO_STR4(type, ne, eps, broadcast);
}
test_norm_mul_add(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {128, 2, 1, 1},
float eps = 1e-5f,
bool broadcast = false)
: type(type), ne(ne), eps(eps), broadcast(broadcast) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
std::array<int64_t, 4> broadcast_dims = {ne[0], ne[1] * 2, ne[2] * 2, ne[3] * 2};
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, broadcast ? broadcast_dims.data() : ne.data());
ggml_tensor * w = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(a); ggml_set_param(w); ggml_set_param(b);
ggml_set_name(a, "a"); ggml_set_name(w, "w"); ggml_set_name(b, "b");
// Use a, w and b early to avoid OP_NONE in graph
a = ggml_add(ctx, ggml_add(ctx, a, w), b);
ggml_tensor * n = ggml_norm(ctx, a, eps);
ggml_tensor * m = ggml_mul(ctx, n, w);
ggml_tensor * out = ggml_add(ctx, m, b);
ggml_set_name(out, "out");
return out;
}
};
// GGML_OP_RMS_NORM
struct test_rms_norm : public test_case {
const ggml_type type;
@ -4455,6 +4518,44 @@ struct test_group_norm : public test_case {
}
};
// GGML_OP_GROUP_NORM + GGML_OP_MUL + GGML_OP_ADD
struct test_group_norm_mul_add : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
int num_groups;
float eps;
std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
return "GROUP_NORM_MUL_ADD";
}
bool run_whole_graph() override { return true; }
std::string vars() override {
return VARS_TO_STR4(type, ne, num_groups, eps);
}
test_group_norm_mul_add(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {128, 1, 1, 1},
int num_groups = 4,
float eps = 1e-5f)
: type(type), ne(ne), num_groups(num_groups), eps(eps) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_tensor * w = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(a); ggml_set_param(w); ggml_set_param(b);
ggml_set_name(a, "a"); ggml_set_name(w, "w"); ggml_set_name(b, "b");
ggml_tensor * n = ggml_group_norm(ctx, a, num_groups, eps);
ggml_tensor * m = ggml_mul(ctx, n, w);
ggml_tensor * out = ggml_add(ctx, m, b);
ggml_set_name(out, "out");
return out;
}
};
// GGML_OP_L2_NORM
struct test_l2_norm : public test_case {
const ggml_type type;
@ -5845,6 +5946,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
}
for (uint32_t n : {1, 511, 1025, 8192, 33*512}) {
for (bool multi_add : {false, true}) {
@ -5997,6 +6100,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
// test large experts*tokens
for (bool b : {false, true}) {
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, b, 32, 1024, 16));
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 2, 2, b, 32, 8192, 64));
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, b, 50, 200, 64));
}
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));
@ -6231,6 +6336,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 }));
test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}));
test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {64, 64, 320, 1}));
test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {9, 9, 1280, 1}));
test_cases.emplace_back(new test_acc());
test_cases.emplace_back(new test_pad());
test_cases.emplace_back(new test_pad_reflect_1d());
@ -6378,6 +6485,24 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
}
}
// qwen3-30b-a3b
for (int bs : {1, 4, 8, 512}) {
for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) {
for (ggml_type type_b : {GGML_TYPE_F32}) {
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 128, 8, false, 768, bs, 2048, 1));
}
}
}
// gpt-oss-20b
for (int bs : {1, 4, 8, 512}) {
for (ggml_type type_a : {GGML_TYPE_MXFP4}) {
for (ggml_type type_b : {GGML_TYPE_F32}) {
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 2880, bs, 2880, 1));
}
}
}
for (int K : {3, 5}) {
for (int IC : {256, 2560}) {
for (int IW_IH : {32, 64, 256}) {

View File

@ -1621,6 +1621,140 @@ static void test_template_output_parsers() {
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
}));
}
{
// Seed-OSS format tests
auto tmpls = read_templates("models/templates/ByteDance-Seed-OSS.jinja");
std::vector<std::string> end_tokens{ "<seed:eos>" };
assert_equals(COMMON_CHAT_FORMAT_SEED_OSS, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
assert_equals(COMMON_CHAT_FORMAT_SEED_OSS, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
// Test simple reasoning content
assert_msg_equals(
simple_assist_msg("Hello, world!", "I'm thinking about the answer"),
common_chat_parse(
"<seed:think>I'm thinking about the answer</seed:think>Hello, world!",
/* is_partial= */ false,
{
/* .format = */ COMMON_CHAT_FORMAT_SEED_OSS,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
}));
// Test budget reflection tags
common_chat_msg msg_budget_reflect;
msg_budget_reflect.role = "assistant";
msg_budget_reflect.content = "<seed:cot_budget_reflect>Token usage: 45/1000\nI should continue thinking to find the best solution.</seed:cot_budget_reflect>I need to calculate this step by step.";
msg_budget_reflect.reasoning_content = "Token usage: 45/1000\nI should continue thinking to find the best solution.";
assert_msg_equals(
msg_budget_reflect,
common_chat_parse(
"<seed:think>Token usage: 45/1000\nI should continue thinking to find the best solution.</seed:think>"
"<seed:cot_budget_reflect>Token usage: 45/1000\nI should continue thinking to find the best solution.</seed:cot_budget_reflect>"
"I need to calculate this step by step.",
/* is_partial= */ false,
{
/* .format = */ COMMON_CHAT_FORMAT_SEED_OSS,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
}));
// Test tool calls with Seed-OSS format
common_chat_msg msg_tool_call;
msg_tool_call.role = "assistant";
msg_tool_call.tool_calls.push_back({"calculate_sum", "{\"numbers\": [1, 2, 3]}", ""});
assert_msg_equals(
msg_tool_call,
common_chat_parse(
"<seed:tool_call>\n"
"<function=calculate_sum>\n"
"<parameter=numbers>[1, 2, 3]</parameter>\n"
"</function>\n"
"</seed:tool_call>",
/* is_partial= */ false,
{COMMON_CHAT_FORMAT_SEED_OSS}));
// Test reasoning + tool call combination
common_chat_msg msg_reasoning_tool;
msg_reasoning_tool.role = "assistant";
msg_reasoning_tool.content = "";
msg_reasoning_tool.reasoning_content = "I need to calculate the sum of these numbers";
msg_reasoning_tool.tool_calls.push_back({"calculate_sum", "{\"numbers\": [1, 2, 3]}", ""});
assert_msg_equals(
msg_reasoning_tool,
common_chat_parse(
"<seed:think>I need to calculate the sum of these numbers</seed:think>"
"<seed:tool_call>\n"
"<function=calculate_sum>\n"
"<parameter=numbers>[1, 2, 3]</parameter>\n"
"</function>\n"
"</seed:tool_call>",
/* is_partial= */ false,
{
/* .format = */ COMMON_CHAT_FORMAT_SEED_OSS,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
}));
// Test deltas: the number of tool calls in partial parses should never decrease
std::string tool_msg = "<seed:tool_call>\n"
"<function=fun>\n"
"<parameter=smth>[1, 2, 3]</parameter>\n"
"</function>";
std::size_t previousToolCalls = 0;
for (std::size_t i = std::string("<seed:tool_call>").length(); i < tool_msg.length() - 1; i++) {
auto partial = tool_msg.substr(0, i);
auto partial_res = common_chat_parse(partial, true, { COMMON_CHAT_FORMAT_SEED_OSS, COMMON_REASONING_FORMAT_DEEPSEEK });
if (partial_res.tool_calls.size() < previousToolCalls) {
throw std::runtime_error("Tool call size decreased on partial: " + partial + " from " + std::to_string(previousToolCalls) + " to " + std::to_string(partial_res.tool_calls.size()));
}
previousToolCalls = partial_res.tool_calls.size();
}
// Test multiple parameters in tool call
common_chat_msg msg_multi_param;
msg_multi_param.role = "assistant";
msg_multi_param.tool_calls.push_back({"process_data", "{\"input\": \"test\", \"format\": \"json\"}", ""});
assert_msg_equals(
msg_multi_param,
common_chat_parse(
"<seed:tool_call>\n"
"<function=process_data>\n"
"<parameter=input>test</parameter>\n"
"<parameter=format>json</parameter>\n"
"</function>\n"
"</seed:tool_call>",
/* is_partial= */ false,
{COMMON_CHAT_FORMAT_SEED_OSS}));
// Test partial parsing for incomplete tool call - don't actually add the call until parsing parameters is done
assert_msg_equals(
simple_assist_msg("", ""),
common_chat_parse(
"<seed:tool_call>\n"
"<function=calculate_sum>\n"
"<parameter=numbers>[1,\n",
/* is_partial= */ true,
{COMMON_CHAT_FORMAT_SEED_OSS}));
// Test incomplete reasoning tag
assert_msg_equals(
simple_assist_msg("", "I was thinking"),
common_chat_parse(
"<seed:think>I was thinking",
/* is_partial= */ true,
{
/* .format = */ COMMON_CHAT_FORMAT_SEED_OSS,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
}));
// Test content without reasoning
assert_msg_equals(
simple_assist_msg("This is a simple response without reasoning."),
common_chat_parse(
"This is a simple response without reasoning.",
/* is_partial= */ false,
{COMMON_CHAT_FORMAT_SEED_OSS}));
}
}
static void test_msg_diffs_compute() {

View File

@ -3,7 +3,6 @@
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-cpu.h"
#include "ggml-opt.h"
#include <cmath>
@ -899,6 +898,7 @@ static std::pair<int, int> test_backend(
int main(void) {
ggml_log_set(nullptr, nullptr);
ggml_backend_load_all();
const size_t dev_count = ggml_backend_dev_count();
printf("Testing %zu devices\n\n", dev_count);
size_t n_ok = 0;
@ -911,11 +911,12 @@ int main(void) {
ggml_backend_t backend = ggml_backend_dev_init(devs[i], NULL);
GGML_ASSERT(backend != NULL);
#ifndef _MSC_VER
if (ggml_backend_is_cpu(backend)) {
ggml_backend_cpu_set_n_threads(backend, std::thread::hardware_concurrency() / 2);
auto * reg = ggml_backend_dev_backend_reg(devs[i]);
auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
if (ggml_backend_set_n_threads_fn) {
ggml_backend_set_n_threads_fn(backend, std::thread::hardware_concurrency() / 2);
}
#endif
backends.push_back(backend);
}

View File

@ -124,7 +124,7 @@ int main(int argc, char ** argv) {
const int tg = n_tg[i_tg];
const int pl = n_pl[i_pl];
const int n_ctx_req = is_pp_shared ? pp + pl*tg : pl*(pp + tg);
const int n_ctx_req = is_pp_shared ? (params.kv_unified ? pp : pl*pp) + pl*tg : pl*(pp + tg);
if (n_ctx_req > n_kv_max) {
continue;
@ -147,13 +147,24 @@ int main(int argc, char ** argv) {
return 1;
}
const auto t_pp_end = ggml_time_us();
if (is_pp_shared) {
for (int32_t i = 1; i < pl; ++i) {
llama_memory_seq_cp(mem, 0, i, -1, -1);
}
}
const auto t_pp_end = ggml_time_us();
if (!params.kv_unified) {
// run one dummy token to apply the memory copy
common_batch_clear(batch);
common_batch_add(batch, get_token_rand(), pp + 0, { 0 }, true);
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
LOG_ERR("%s: llama_decode() failed\n", __func__);
return 1;
}
llama_memory_seq_rm(mem, 0, pp, -1);
}
}
const auto t_tg_start = ggml_time_us();
@ -180,7 +191,7 @@ int main(int argc, char ** argv) {
const float speed_pp = is_pp_shared ? pp / t_pp : pl*pp / t_pp;
const float speed_tg = pl*tg / t_tg;
const float speed = n_kv / t;
const float speed = ((is_pp_shared ? pp : pl*pp) + pl*tg) / t;
if(params.batched_bench_output_jsonl) {
LOG(

View File

@ -587,12 +587,12 @@ int main(int argc, char ** argv) {
if (n_past + (int) embd.size() >= n_ctx) {
if (!params.ctx_shift){
LOG_DBG("\n\n%s: context full and context shift is disabled => stopping\n", __func__);
LOG_WRN("\n\n%s: context full and context shift is disabled => stopping\n", __func__);
break;
}
if (params.n_predict == -2) {
LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
LOG_WRN("\n\n%s: context full and n_predict == %d => stopping\n", __func__, params.n_predict);
break;
}

View File

@ -55,6 +55,8 @@ add_executable(llama-qwen2vl-cli deprecation-warning.cpp)
set(TARGET llama-mtmd-cli)
add_executable (${TARGET} mtmd-cli.cpp)
set_target_properties (${TARGET} PROPERTIES OUTPUT_NAME llama-mtmd-cli)
install (TARGETS ${TARGET} RUNTIME)
if(NOT CMAKE_SYSTEM_NAME STREQUAL "iOS")
install(TARGETS ${TARGET} RUNTIME)
endif()
target_link_libraries (${TARGET} PRIVATE common mtmd Threads::Threads)
target_compile_features(${TARGET} PRIVATE cxx_std_17)

View File

@ -135,6 +135,7 @@ enum projector_type {
PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
PROJECTOR_TYPE_VOXTRAL,
PROJECTOR_TYPE_LFM2,
PROJECTOR_TYPE_KIMIVL,
PROJECTOR_TYPE_UNKNOWN,
};
@ -156,6 +157,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_QWEN25O, "qwen2.5o"},
{ PROJECTOR_TYPE_VOXTRAL, "voxtral"},
{ PROJECTOR_TYPE_LFM2, "lfm2"},
{ PROJECTOR_TYPE_KIMIVL, "kimivl"},
};
static projector_type clip_projector_type_from_string(const std::string & str) {

View File

@ -526,57 +526,16 @@ struct clip_graph {
cur);
} else if (ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3) {
// pixel_shuffle
// https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
const int scale_factor = model.hparams.proj_scale_factor;
const int n_embd = cur->ne[0];
const int seq = cur->ne[1];
const int bsz = 1; // batch size, always 1 for now since we don't support batching
const int height = std::sqrt(seq);
const int width = std::sqrt(seq);
GGML_ASSERT(scale_factor != 0);
cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height, bsz);
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
cur = ggml_cont_4d(ctx0, cur,
n_embd * scale_factor * scale_factor,
height / scale_factor,
width / scale_factor,
bsz);
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
cur = ggml_cont_3d(ctx0, cur,
n_embd * scale_factor * scale_factor,
seq / (scale_factor * scale_factor),
bsz);
cur = build_patch_merge_permute(cur, scale_factor);
cur = ggml_mul_mat(ctx0, model.projection, cur);
} else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
// pixel unshuffle block
const int scale_factor = model.hparams.proj_scale_factor;
GGML_ASSERT(scale_factor > 1);
const int n_embd = cur->ne[0];
int width = img.nx / patch_size;
int height = img.ny / patch_size;
// pad width and height to factor
const int64_t pad_width = CLIP_ALIGN(width, scale_factor) - width;
const int64_t pad_height = CLIP_ALIGN(height, scale_factor) - height;
cur = ggml_reshape_3d(ctx0, cur, n_embd, width, height);
if (pad_width || pad_height) {
cur = ggml_pad(ctx0, cur, 0, pad_width, pad_height, 0);
width += pad_width;
height += pad_height;
}
// unshuffle h
cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height);
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
// unshuffle w
cur = ggml_cont_3d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor);
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
cur = ggml_cont_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
cur = build_patch_merge_permute(cur, scale_factor);
// projection
cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm
@ -1086,7 +1045,7 @@ struct clip_graph {
n_patches_x / scale_factor,
n_patches_y / scale_factor,
bsz);
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
//cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
// flatten to 2D
cur = ggml_cont_2d(ctx0, cur,
n_embd * scale_factor * scale_factor,
@ -1113,6 +1072,67 @@ struct clip_graph {
return gf;
}
ggml_cgraph * build_kimivl() {
// 2D input positions
ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
ggml_set_name(pos_h, "pos_h");
ggml_set_input(pos_h);
ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
ggml_set_name(pos_w, "pos_w");
ggml_set_input(pos_w);
ggml_tensor * learned_pos_embd = resize_position_embeddings();
// build ViT with 2D position embeddings
auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
// first half is X axis and second half is Y axis
return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
};
ggml_tensor * inp = build_inp();
ggml_tensor * cur = build_vit(
inp, n_patches,
NORM_TYPE_NORMAL,
hparams.ffn_op,
learned_pos_embd,
add_pos);
cb(cur, "vit_out", -1);
{
// patch_merger
const int scale_factor = model.hparams.proj_scale_factor;
cur = build_patch_merge_permute(cur, scale_factor);
// projection norm
int proj_inp_dim = cur->ne[0];
cur = ggml_view_2d(ctx0, cur,
n_embd, cur->ne[1] * scale_factor * scale_factor,
ggml_row_size(cur->type, n_embd), 0);
cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm
cur = ggml_mul(ctx0, cur, model.mm_input_norm_w);
cur = ggml_add(ctx0, cur, model.mm_input_norm_b);
cur = ggml_view_2d(ctx0, cur,
proj_inp_dim, cur->ne[1] / scale_factor / scale_factor,
ggml_row_size(cur->type, proj_inp_dim), 0);
cb(cur, "proj_inp_normed", -1);
// projection mlp
cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
cur = ggml_add(ctx0, cur, model.mm_1_b);
cur = ggml_gelu(ctx0, cur);
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
cur = ggml_add(ctx0, cur, model.mm_2_b);
cb(cur, "proj_out", -1);
}
// build the graph
ggml_build_forward_expand(gf, cur);
return gf;
}
// this graph is used by llava, granite and glm
// due to having embedding_stack (used by granite), we cannot reuse build_vit
ggml_cgraph * build_llava() {
@ -1611,18 +1631,20 @@ private:
ggml_tensor * pos_embd = model.position_embeddings;
const int height = img.ny / patch_size;
const int width = img.nx / patch_size;
const uint32_t mode = GGML_SCALE_MODE_BILINEAR;
const int n_per_side = (int)std::sqrt(pos_embd->ne[1]);
if (!pos_embd || height * width == pos_embd->ne[1]) {
GGML_ASSERT(pos_embd);
if (height == n_per_side && width == n_per_side) {
return pos_embd;
}
const int n_pos_embd = std::sqrt(pos_embd->ne[1]);
pos_embd = ggml_reshape_3d(ctx0, pos_embd, n_embd, n_pos_embd, n_pos_embd); // -> (n_embd, n_pos_embd, n_pos_embd)
pos_embd = ggml_permute(ctx0, pos_embd, 2, 0, 1, 3); // -> (n_pos_embd, n_pos_embd, n_embd)
pos_embd = ggml_interpolate(ctx0, pos_embd, width, height, n_embd, 1, 1); // -> (width, height, n_embd)
pos_embd = ggml_reshape_2d(ctx0, pos_embd, height * width, n_embd); // -> (height * width, n_embd)
pos_embd = ggml_transpose(ctx0, pos_embd); // -> (n_embd, height * width)
pos_embd = ggml_cont(ctx0, pos_embd);
pos_embd = ggml_reshape_3d(ctx0, pos_embd, n_embd, n_per_side, n_per_side); // -> (n_embd, n_per_side, n_per_side)
pos_embd = ggml_permute(ctx0, pos_embd, 2, 0, 1, 3); // -> (n_per_side, n_per_side, n_embd)
pos_embd = ggml_interpolate(ctx0, pos_embd, width, height, n_embd, 1, mode); // -> (width, height, n_embd)
pos_embd = ggml_permute(ctx0, pos_embd, 1, 2, 0, 3); // -> (n_embd, width, height)
pos_embd = ggml_cont_2d(ctx0, pos_embd, n_embd, width * height); // -> (n_embd, width * height)
return pos_embd;
}
@ -2021,6 +2043,39 @@ private:
return cur;
}
// aka pixel_shuffle / pixel_unshuffle / patch_merger (Kimi-VL)
// support dynamic resolution
ggml_tensor * build_patch_merge_permute(ggml_tensor * cur, int scale_factor) {
GGML_ASSERT(scale_factor > 1);
const int n_embd = cur->ne[0];
int width = img.nx / patch_size;
int height = img.ny / patch_size;
// pad width and height to factor
const int64_t pad_width = CLIP_ALIGN(width, scale_factor) - width;
const int64_t pad_height = CLIP_ALIGN(height, scale_factor) - height;
cur = ggml_reshape_3d(ctx0, cur, n_embd, width, height);
if (pad_width || pad_height) {
cur = ggml_pad(ctx0, cur, 0, pad_width, pad_height, 0);
width += pad_width;
height += pad_height;
}
// unshuffle h
cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height);
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
// unshuffle w
cur = ggml_cont_3d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor);
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
cur = ggml_cont_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
cb(cur, "pixel_shuffle", -1);
return cur;
}
};
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
@ -2063,6 +2118,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{
res = graph.build_whisper_enc();
} break;
case PROJECTOR_TYPE_KIMIVL:
{
res = graph.build_kimivl();
} break;
default:
{
res = graph.build_llava();
@ -2202,6 +2261,8 @@ struct clip_model_loader {
hparams.minicpmv_query_num = 64;
} else if (hparams.minicpmv_version == 5) {
hparams.minicpmv_query_num = 64;
} else if (hparams.minicpmv_version == 6) {
hparams.minicpmv_query_num = 64;
} else {
hparams.minicpmv_query_num = 96;
}
@ -2311,6 +2372,12 @@ struct clip_model_loader {
hparams.image_size = 1024;
get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
} break;
case PROJECTOR_TYPE_KIMIVL:
{
hparams.rope_theta = 10000.0f;
hparams.warmup_image_size = hparams.patch_size * 8;
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
} break;
case PROJECTOR_TYPE_GEMMA3:
{
// default value (used by all model sizes in gemma 3 family)
@ -2475,7 +2542,20 @@ struct clip_model_loader {
// some models already exported with legacy (incorrect) naming which is quite messy, let's fix it here
// note: Qwen model converted from the old surgery script has n_ff = 0, so we cannot use n_ff to check!
if (layer.ff_up_w && layer.ff_down_w && layer.ff_down_w->ne[0] == hparams.n_embd) {
bool is_ffn_swapped = (
// only old models need this fix
model.proj_type == PROJECTOR_TYPE_MLP
|| model.proj_type == PROJECTOR_TYPE_MLP_NORM
|| model.proj_type == PROJECTOR_TYPE_LDP
|| model.proj_type == PROJECTOR_TYPE_LDPV2
|| model.proj_type == PROJECTOR_TYPE_QWEN2VL
|| model.proj_type == PROJECTOR_TYPE_QWEN25VL
|| model.proj_type == PROJECTOR_TYPE_GLM_EDGE
|| model.proj_type == PROJECTOR_TYPE_GEMMA3
|| model.proj_type == PROJECTOR_TYPE_IDEFICS3
|| model.proj_type == PROJECTOR_TYPE_MINICPMV
) && layer.ff_up_w && layer.ff_down_w && layer.ff_down_w->ne[0] == hparams.n_embd;
if (is_ffn_swapped) {
// swap up and down weights
ggml_tensor * tmp = layer.ff_up_w;
layer.ff_up_w = layer.ff_down_w;
@ -2484,6 +2564,9 @@ struct clip_model_loader {
tmp = layer.ff_up_b;
layer.ff_up_b = layer.ff_down_b;
layer.ff_down_b = tmp;
if (il == 0) {
LOG_WRN("%s: ffn up/down are swapped\n", __func__);
}
}
}
@ -2602,6 +2685,7 @@ struct clip_model_loader {
model.projection = get_tensor(TN_MM_PROJECTOR);
} break;
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_KIMIVL:
{
model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM);
model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B);
@ -3505,7 +3589,9 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
res_imgs->grid_y = inst.grid_size.height;
return true;
} else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
} else if ( ctx->proj_type() == PROJECTOR_TYPE_LFM2
|| ctx->proj_type() == PROJECTOR_TYPE_KIMIVL
) {
GGML_ASSERT(params.proj_scale_factor);
// smart resize
@ -3685,6 +3771,9 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
} else if (params.minicpmv_version == 5) {
// MiniCPM-V 4.0
n_patches = 64;
} else if (params.minicpmv_version == 6) {
// MiniCPM-V 4.5
n_patches = 64;
} else {
GGML_ABORT("Unknown minicpmv version");
}
@ -3703,12 +3792,21 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
case PROJECTOR_TYPE_IDEFICS3:
case PROJECTOR_TYPE_INTERNVL:
case PROJECTOR_TYPE_LLAMA4:
case PROJECTOR_TYPE_LFM2:
{
// both W and H are divided by proj_scale_factor
// both X and Y are downscaled by the scale factor
int scale_factor = ctx->model.hparams.proj_scale_factor;
n_patches /= (scale_factor * scale_factor);
} break;
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_KIMIVL:
{
// dynamic size
int scale_factor = ctx->model.hparams.proj_scale_factor;
int out_patch_size = params.patch_size * scale_factor;
int x_patch = CLIP_ALIGN(img->nx, out_patch_size) / out_patch_size;
int y_patch = CLIP_ALIGN(img->ny, out_patch_size) / out_patch_size;
n_patches = x_patch * y_patch;
} break;
case PROJECTOR_TYPE_PIXTRAL:
{
// dynamic size
@ -4091,6 +4189,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
set_input_i32("positions", positions);
} break;
case PROJECTOR_TYPE_PIXTRAL:
case PROJECTOR_TYPE_KIMIVL:
{
// set the 2D positions
int n_patches_per_col = image_size_width / patch_size;
@ -4245,6 +4344,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
case PROJECTOR_TYPE_QWEN2A:
return ctx->model.mm_fc_w->ne[1];
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_KIMIVL:
return ctx->model.mm_2_w->ne[1];
default:
GGML_ABORT("Unknown projector type");

View File

@ -607,6 +607,9 @@ else:
elif minicpmv_version == 5:
emb_dim = 2560
block_count = 27
elif minicpmv_version == 6:
emb_dim = 4096
block_count = 27
default_vision_config = {
"hidden_size": 1152,
@ -630,6 +633,10 @@ elif minicpmv_version == 5:
default_vision_config["model_type"] = "siglip_vision_model"
vision_config = SiglipVisionConfig(**default_vision_config)
model = SiglipVisionTransformer(vision_config)
elif minicpmv_version == 6:
default_vision_config["model_type"] = "siglip_vision_model"
vision_config = SiglipVisionConfig(**default_vision_config)
model = SiglipVisionTransformer(vision_config)
processor = None
# if model.attn_pool is not None:

View File

@ -207,7 +207,7 @@ struct mtmd_context {
tok_row_end_trail = false; // no trailing end-of-row token
ov_img_first = true;
} else if (minicpmv_version == 3 || minicpmv_version == 4 || minicpmv_version == 5) {
} else if (minicpmv_version == 3 || minicpmv_version == 4 || minicpmv_version == 5 || minicpmv_version == 6) {
// minicpmv 2.6 format:
// <image> (overview) </image><slice> (slice) </slice><slice> (slice) </slice>\n ...
slice_tmpl = MTMD_SLICE_TMPL_MINICPMV_2_6;

View File

@ -86,6 +86,7 @@ if [ "$RUN_BIG_TESTS" = true ]; then
add_test_vision "ggml-org/InternVL3-14B-Instruct-GGUF:Q4_K_M"
add_test_vision "ggml-org/Qwen2.5-Omni-7B-GGUF:Q4_K_M"
# add_test_vision "ggml-org/Qwen2.5-VL-32B-Instruct-GGUF:Q4_K_M" # does not work on my mac M3 Ultra
add_test_vision "ggml-org/Kimi-VL-A3B-Thinking-2506-GGUF:Q4_K_M"
add_test_audio "ggml-org/ultravox-v0_5-llama-3_1-8b-GGUF:Q4_K_M"
add_test_audio "ggml-org/Qwen2.5-Omni-7B-GGUF:Q4_K_M"

View File

@ -62,7 +62,6 @@ The project is under active development, and we are [looking for feedback and co
| `--yarn-attn-factor N` | YaRN: scale sqrt(t) or attention magnitude (default: 1.0)<br/>(env: LLAMA_ARG_YARN_ATTN_FACTOR) |
| `--yarn-beta-slow N` | YaRN: high correction dim or alpha (default: 1.0)<br/>(env: LLAMA_ARG_YARN_BETA_SLOW) |
| `--yarn-beta-fast N` | YaRN: low correction dim or beta (default: 32.0)<br/>(env: LLAMA_ARG_YARN_BETA_FAST) |
| `-dkvc, --dump-kv-cache` | verbose print of the KV cache |
| `-nkvo, --no-kv-offload` | disable KV offload<br/>(env: LLAMA_ARG_NO_KV_OFFLOAD) |
| `-ctk, --cache-type-k TYPE` | KV cache data type for K<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_K) |
| `-ctv, --cache-type-v TYPE` | KV cache data type for V<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_V) |
@ -1143,6 +1142,8 @@ The `response_format` parameter supports both plain JSON output (e.g. `{"type":
`parse_tool_calls`: Whether to parse the generated tool call.
`parallel_tool_calls` : Whether to enable parallel/multiple tool calls (only supported on some models, verification is based on jinja template).
*Examples:*
You can use either Python `openai` library with appropriate checkpoints:

View File

@ -4898,6 +4898,8 @@ int main(int argc, char ** argv) {
{"id", i},
{"path", lora.path},
{"scale", lora.scale},
{"task_name", lora.task_name},
{"prompt_prefix", lora.prompt_prefix},
});
}
res_ok(res, result);

View File

@ -26,10 +26,7 @@ from re import RegexFlag
import wget
DEFAULT_HTTP_TIMEOUT = 12
if "LLAMA_SANITIZE" in os.environ or "GITHUB_ACTION" in os.environ:
DEFAULT_HTTP_TIMEOUT = 30
DEFAULT_HTTP_TIMEOUT = 30
class ServerResponse: