Merge branch 'master' into HEAD
This commit is contained in:
commit
fdac9686f7
|
|
@ -16,7 +16,7 @@ The project differentiates between 3 levels of contributors:
|
||||||
- If you modified a `ggml` operator or added a new one, add the corresponding test cases to `test-backend-ops`
|
- If you modified a `ggml` operator or added a new one, add the corresponding test cases to `test-backend-ops`
|
||||||
- Create separate PRs for each feature or fix. Avoid combining unrelated changes in a single PR
|
- Create separate PRs for each feature or fix. Avoid combining unrelated changes in a single PR
|
||||||
- Consider allowing write access to your branch for faster reviews, as reviewers can push commits directly
|
- Consider allowing write access to your branch for faster reviews, as reviewers can push commits directly
|
||||||
- If your PR becomes stale, don't hesitate to ping the maintainers in the comments
|
- If your PR becomes stale, rebase it on top of the latest `master` to get the maintainers' attention
|
||||||
- Maintainers will rely on your insights and approval when making a final decision to approve and merge a PR
|
- Maintainers will rely on your insights and approval when making a final decision to approve and merge a PR
|
||||||
- Consider adding yourself to [CODEOWNERS](CODEOWNERS) to indicate your availability for reviewing related PRs
|
- Consider adding yourself to [CODEOWNERS](CODEOWNERS) to indicate your availability for reviewing related PRs
|
||||||
- Using AI to generate PRs is permitted. However, you must (1) explicitly disclose how AI was used and (2) conduct a thorough manual review before publishing the PR. Note that trivial tab autocompletions do not require disclosure.
|
- Using AI to generate PRs is permitted. However, you must (1) explicitly disclose how AI was used and (2) conduct a thorough manual review before publishing the PR. Note that trivial tab autocompletions do not require disclosure.
|
||||||
|
|
|
||||||
|
|
@ -1524,6 +1524,79 @@ class TextModel(ModelBase):
|
||||||
special_vocab._set_special_token("bos", 151643)
|
special_vocab._set_special_token("bos", 151643)
|
||||||
special_vocab.add_to_gguf(self.gguf_writer)
|
special_vocab.add_to_gguf(self.gguf_writer)
|
||||||
|
|
||||||
|
def _set_vocab_mistral(self):
    """Write the Mistral tokenizer vocabulary and chat template to the GGUF writer.

    The vocabulary is loaded from ``self.dir_model`` through ``MistralVocab``.
    The tokenizer model, the token list with scores and types, the
    bos/eos/unk/pad token ids and the add-BOS/add-EOS flags are emitted,
    followed by a chat template when one is selected:

    * a local ``chat_template.jinja`` (Mistral format only), otherwise
    * a community template, unless the model is in Mistral format and
      ``self.disable_mistral_community_chat_template`` is set.

    Raises:
        ImportError: if ``mistral-common`` is not installed.
        AssertionError: if the number of tokens yielded by the vocabulary
            differs from ``vocab.vocab_size``.
    """
    if not _mistral_common_installed:
        raise ImportError(_mistral_import_error_msg)

    vocab = MistralVocab(self.dir_model)
    writer = self.gguf_writer

    logger.info(
        f"Converting tokenizer {vocab.tokenizer_type} of size {vocab.vocab_size}."
    )

    writer.add_tokenizer_model(vocab.gguf_tokenizer_model)

    # Walk the vocabulary once, then split the (text, score, type) triples
    # into the three parallel lists the writer expects.
    entries = list(vocab.all_tokens())
    tokens = [text for text, _score, _toktype in entries]
    scores = [score for _text, score, _toktype in entries]
    toktypes = [toktype for _text, _score, toktype in entries]

    assert len(tokens) == vocab.vocab_size, (
        f"token count ({len(tokens)}) != vocab size ({vocab.vocab_size})"
    )

    # Only the tekken tokenizer gets a pre-tokenizer name and merges here.
    if vocab.tokenizer_type == MistralTokenizerType.tekken:
        writer.add_tokenizer_pre("tekken")
        writer.add_token_merges(
            vocab.extract_vocab_merges_from_model()
        )

    logger.info(
        f"Setting bos, eos, unk and pad token IDs to {vocab.bos_id}, {vocab.eos_id}, {vocab.unk_id}, {vocab.pad_id}."
    )

    writer.add_bos_token_id(vocab.bos_id)
    writer.add_eos_token_id(vocab.eos_id)
    writer.add_unk_token_id(vocab.unk_id)
    writer.add_pad_token_id(vocab.pad_id)

    writer.add_token_list(tokens)
    writer.add_token_scores(scores)
    writer.add_token_types(toktypes)
    writer.add_vocab_size(vocab.vocab_size)

    writer.add_add_bos_token(True)
    writer.add_add_eos_token(False)

    local_template_file_path = self.dir_model / "chat_template.jinja"

    if self.is_mistral_format and local_template_file_path.is_file():
        # Ministral-3 and other new Mistral models come with chat templates.
        # ref: https://huggingface.co/mistralai/Ministral-3-14B-Instruct-2512/tree/main
        logger.info("Using an existing Mistral local chat template.")

        with open(local_template_file_path, "r", encoding="utf-8") as template_file:
            template = template_file.read()
    elif not (self.is_mistral_format and self.disable_mistral_community_chat_template):
        template_dir = Path(__file__).parent / "models/templates/"

        # Log only for Mistral format that the official tokenization and detokenization is via `mistral-common`.
        if self.is_mistral_format:
            logger.info(
                "Using a Mistral community chat template. These templates can be subject to errors in early days or weeks after a release. "
                "Mistral recommends to use `mistral-common` to perform tokenization and detokenization."
            )
        template = MistralModel.get_community_chat_template(vocab, template_dir, self.is_mistral_format)
    else:
        logger.info("Not using a Mistral local or community chat template. Ensure to perform the tokenization and detokenization via `mistral-common`.")
        template = None

    if template is not None:
        writer.add_chat_template(template)
|
||||||
|
|
||||||
|
|
||||||
class MmprojModel(ModelBase):
|
class MmprojModel(ModelBase):
|
||||||
model_type = ModelType.MMPROJ
|
model_type = ModelType.MMPROJ
|
||||||
|
|
@ -2294,79 +2367,6 @@ class LlamaModel(TextModel):
|
||||||
if self.hf_arch == "VLlama3ForCausalLM":
|
if self.hf_arch == "VLlama3ForCausalLM":
|
||||||
self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
|
self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
|
||||||
|
|
||||||
def _set_vocab_mistral(self):
|
|
||||||
if not _mistral_common_installed:
|
|
||||||
raise ImportError(_mistral_import_error_msg)
|
|
||||||
|
|
||||||
vocab = MistralVocab(self.dir_model)
|
|
||||||
logger.info(
|
|
||||||
f"Converting tokenizer {vocab.tokenizer_type} of size {vocab.vocab_size}."
|
|
||||||
)
|
|
||||||
|
|
||||||
self.gguf_writer.add_tokenizer_model(vocab.gguf_tokenizer_model)
|
|
||||||
|
|
||||||
tokens = []
|
|
||||||
scores = []
|
|
||||||
toktypes = []
|
|
||||||
|
|
||||||
for text, score, toktype in vocab.all_tokens():
|
|
||||||
tokens.append(text)
|
|
||||||
scores.append(score)
|
|
||||||
toktypes.append(toktype)
|
|
||||||
|
|
||||||
assert len(tokens) == vocab.vocab_size, (
|
|
||||||
f"token count ({len(tokens)}) != vocab size ({vocab.vocab_size})"
|
|
||||||
)
|
|
||||||
|
|
||||||
if vocab.tokenizer_type == MistralTokenizerType.tekken:
|
|
||||||
self.gguf_writer.add_tokenizer_pre("tekken")
|
|
||||||
self.gguf_writer.add_token_merges(
|
|
||||||
vocab.extract_vocab_merges_from_model()
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Setting bos, eos, unk and pad token IDs to {vocab.bos_id}, {vocab.eos_id}, {vocab.unk_id}, {vocab.pad_id}."
|
|
||||||
)
|
|
||||||
|
|
||||||
self.gguf_writer.add_bos_token_id(vocab.bos_id)
|
|
||||||
self.gguf_writer.add_eos_token_id(vocab.eos_id)
|
|
||||||
self.gguf_writer.add_unk_token_id(vocab.unk_id)
|
|
||||||
self.gguf_writer.add_pad_token_id(vocab.pad_id)
|
|
||||||
|
|
||||||
self.gguf_writer.add_token_list(tokens)
|
|
||||||
self.gguf_writer.add_token_scores(scores)
|
|
||||||
self.gguf_writer.add_token_types(toktypes)
|
|
||||||
self.gguf_writer.add_vocab_size(vocab.vocab_size)
|
|
||||||
|
|
||||||
self.gguf_writer.add_add_bos_token(True)
|
|
||||||
self.gguf_writer.add_add_eos_token(False)
|
|
||||||
|
|
||||||
local_template_file_path = self.dir_model / "chat_template.jinja"
|
|
||||||
|
|
||||||
if self.is_mistral_format and local_template_file_path.is_file():
|
|
||||||
# Ministral-3 and other new Mistral models come with chat templates.
|
|
||||||
# ref: https://huggingface.co/mistralai/Ministral-3-14B-Instruct-2512/tree/main
|
|
||||||
logger.info("Using an existing Mistral local chat template.")
|
|
||||||
|
|
||||||
with open(local_template_file_path, "r", encoding="utf-8") as f:
|
|
||||||
template = f.read()
|
|
||||||
elif not self.is_mistral_format or not self.disable_mistral_community_chat_template:
|
|
||||||
template_dir = Path(__file__).parent / "models/templates/"
|
|
||||||
|
|
||||||
# Log only for Mistral format that the official tokenization and detokenization is via `mistral-common`.
|
|
||||||
if self.is_mistral_format:
|
|
||||||
logger.info(
|
|
||||||
"Using a Mistral community chat template. These templates can be subject to errors in early days or weeks after a release. "
|
|
||||||
"Mistral recommends to use `mistral-common` to perform tokenization and detokenization."
|
|
||||||
)
|
|
||||||
template = MistralModel.get_community_chat_template(vocab, template_dir, self.is_mistral_format)
|
|
||||||
else:
|
|
||||||
logger.info("Not using a Mistral local or community chat template. Ensure to perform the tokenization and detokenization via `mistral-common`.")
|
|
||||||
template = None
|
|
||||||
|
|
||||||
if template is not None:
|
|
||||||
self.gguf_writer.add_chat_template(template)
|
|
||||||
|
|
||||||
def set_vocab(self):
|
def set_vocab(self):
|
||||||
if self.is_mistral_format:
|
if self.is_mistral_format:
|
||||||
return self._set_vocab_mistral()
|
return self._set_vocab_mistral()
|
||||||
|
|
@ -9924,17 +9924,109 @@ class MistralModel(LlamaModel):
|
||||||
|
|
||||||
def set_gguf_parameters(self):
    """Emit the parent class's GGUF parameters, then the Mistral-specific ones.

    The Mistral-specific part is delegated to the static helper
    ``MistralModel.set_mistral_config``.
    """
    super().set_gguf_parameters()
    writer, hparams = self.gguf_writer, self.hparams
    MistralModel.set_mistral_config(writer, hparams)
|
||||||
yarn_params = self.hparams["yarn"]
|
|
||||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
|
|
||||||
self.gguf_writer.add_rope_scaling_factor(yarn_params["factor"])
|
|
||||||
self.gguf_writer.add_rope_scaling_yarn_beta_fast(yarn_params["beta"])
|
|
||||||
self.gguf_writer.add_rope_scaling_yarn_beta_slow(yarn_params["alpha"])
|
|
||||||
self.gguf_writer.add_rope_scaling_yarn_log_mul(1.0) # mscale_all_dim
|
|
||||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(yarn_params["original_max_position_embeddings"])
|
|
||||||
|
|
||||||
if "llama_4_scaling" in self.hparams:
|
@staticmethod
def set_mistral_config(gguf_writer: gguf.GGUFWriter, hparams: dict):
    """Write Mistral-specific rope-scaling and attention settings.

    Args:
        gguf_writer: writer that receives the metadata.
        hparams: model hyper-parameters. Only the optional ``"yarn"`` and
            ``"llama_4_scaling"`` entries are read; a missing entry means
            nothing is written for it.

    Raises:
        KeyError: if a present ``"yarn"`` / ``"llama_4_scaling"`` entry lacks
            one of the sub-keys read below.
    """
    has_yarn = "yarn" in hparams
    if has_yarn:
        yarn = hparams["yarn"]
        gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
        gguf_writer.add_rope_scaling_factor(yarn["factor"])
        # The "beta" entry is written as beta_fast and "alpha" as beta_slow.
        gguf_writer.add_rope_scaling_yarn_beta_fast(yarn["beta"])
        gguf_writer.add_rope_scaling_yarn_beta_slow(yarn["alpha"])
        gguf_writer.add_rope_scaling_yarn_log_mul(1.0)  # mscale_all_dim
        gguf_writer.add_rope_scaling_orig_ctx_len(yarn["original_max_position_embeddings"])

    has_attn_scaling = "llama_4_scaling" in hparams
    if has_attn_scaling:
        attn_scaling = hparams["llama_4_scaling"]
        gguf_writer.add_attn_temperature_scale(attn_scaling["beta"])
|
||||||
|
|
||||||
|
|
||||||
|
class MistralMoeModel(DeepseekV2Model):
    """Converter for Mistral-format MoE checkpoints.

    The Mistral hyper-parameters and tensor names are remapped so that the
    ``DeepseekV2Model`` parameter and tensor logic can be reused.
    """

    model_arch = gguf.MODEL_ARCH.DEEPSEEK2
    model_name = "Mistral"
    hf_arch = ""
    is_mistral_format = True

    def __init__(self, *args, **kwargs):
        """Initialise the parent model, then rewrite ``self.hparams`` in place."""
        super().__init__(*args, **kwargs)
        logger.info("Using MistralMoeModel")
        # remap hparams from Mistral MoE format to DeepseekV2 format
        # we do this way to be able to reuse DeepseekV2Model set_gguf_parameters logic
        # ref: https://github.com/vllm-project/vllm/blob/b294e28db2c5dee61bc25157664edcada8b90b31/vllm/transformers_utils/configs/mistral.py
        config = self.hparams

        # (Mistral key, HF key): copied only when the Mistral key is present
        renamed_keys = (
            ("dim", "hidden_size"),
            ("norm_eps", "rms_norm_eps"),
            ("n_kv_heads", "num_key_value_heads"),
            ("n_layers", "num_hidden_layers"),
            ("n_heads", "num_attention_heads"),
            ("hidden_dim", "intermediate_size"),
        )
        # (HF key, Mistral key, default value): always written
        defaulted_keys = (
            ("model_type", "model_type", "transformer"),
            ("hidden_act", "activation", "silu"),
            ("tie_word_embeddings", "tied_embeddings", False),
            ("max_seq_len", "max_seq_len", config.get("max_position_embeddings", 128_000)),
            ("max_position_embeddings", "max_position_embeddings", 128_000),
        )

        # mapping top-level keys
        config.update({
            hf_key: config[mistral_key]
            for mistral_key, hf_key in renamed_keys
            if mistral_key in config
        })
        for hf_key, mistral_key, fallback in defaulted_keys:
            config[hf_key] = config[mistral_key] if mistral_key in config else fallback

        # mapping MoE-specific keys (read from the nested "moe" section)
        moe_keys = (
            ("route_every_n", "moe_layer_freq"),
            ("first_k_dense_replace", "first_k_dense_replace"),
            ("num_experts_per_tok", "num_experts_per_tok"),
            ("num_experts", "n_routed_experts"),
            ("expert_hidden_dim", "moe_intermediate_size"),
            ("routed_scale", "routed_scaling_factor"),
            ("num_shared_experts", "n_shared_experts"),
            ("num_expert_groups", "n_group"),
            ("num_expert_groups_per_tok", "topk_group"),
        )
        moe = config["moe"]
        config.update({
            hf_key: moe[mistral_key]
            for mistral_key, hf_key in moe_keys
            if mistral_key in moe
        })

        # provide missing values
        config.update(
            topk_method=None,
            norm_topk_prob=True,
            scoring_func="softmax",
        )

    def set_vocab(self):
        """Use the Mistral tokenizer path for the vocabulary."""
        self._set_vocab_mistral()

    def set_gguf_parameters(self):
        """Write the parent parameters, the shared Mistral config and MoE-specific overrides.

        NOTE: ``self.hparams["yarn"]`` is read unconditionally here, so a
        config without a ``"yarn"`` section raises ``KeyError``.
        """
        super().set_gguf_parameters()
        writer = self.gguf_writer
        MistralModel.set_mistral_config(writer, self.hparams)
        yarn = self.hparams["yarn"]
        writer.add_attn_temperature_length(yarn["original_max_position_embeddings"])
        writer.add_rope_scaling_yarn_log_mul(0.1)  # mscale_all_dim * 0.1

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
        """Drop vision/projector tensors and rename the rest before delegating to the parent.

        Returns an empty list for tensors whose name starts with ``vision_`` or
        ``patch_merger.`` or contains ``mm_projector``; otherwise returns the
        result of the parent ``modify_tensors`` called with the renamed tensor.
        """
        if name.startswith(("vision_", "patch_merger.")) or "mm_projector" in name:
            return []

        # rename certain tensors so that we can reuse DeepseekV2Model modify_tensors logic
        suffix_renames = (
            (".qscale_act", ".input_scale"),
            (".qscale_weight", ".weight_scale"),
        )
        for old_suffix, new_suffix in suffix_renames:
            if name.endswith(old_suffix):
                name = name.replace(old_suffix, new_suffix)

        # str.replace is a no-op when the fragment is absent
        name = name.replace(".wkv_b.", ".kv_b_proj.")

        if ".experts." in name:
            for old_part, new_part in (
                (".experts.", ".mlp.experts."),
                (".w1.", ".gate_proj."),
                (".w2.", ".down_proj."),
                (".w3.", ".up_proj."),
            ):
                name = name.replace(old_part, new_part)

        return super().modify_tensors(data_torch, f"model.{name}", bid)
|
||||||
|
|
||||||
|
|
||||||
class PixtralModel(LlavaVisionModel):
|
class PixtralModel(LlavaVisionModel):
|
||||||
|
|
@ -10490,6 +10582,8 @@ def main() -> None:
|
||||||
elif args.mmproj:
|
elif args.mmproj:
|
||||||
assert hparams.get("vision_encoder") is not None, "This model does not support multimodal"
|
assert hparams.get("vision_encoder") is not None, "This model does not support multimodal"
|
||||||
model_class = PixtralModel
|
model_class = PixtralModel
|
||||||
|
elif "moe" in hparams:
|
||||||
|
model_class = MistralMoeModel
|
||||||
else:
|
else:
|
||||||
model_class = MistralModel
|
model_class = MistralModel
|
||||||
|
|
||||||
|
|
|
||||||
216
docs/ops.md
216
docs/ops.md
|
|
@ -12,111 +12,111 @@ Legend:
|
||||||
- 🟡 Partially supported by this backend
|
- 🟡 Partially supported by this backend
|
||||||
- ❌ Not supported by this backend
|
- ❌ Not supported by this backend
|
||||||
|
|
||||||
| Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | zDNN |
|
| Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | WebGPU | zDNN |
|
||||||
|-----------|------|------|------|------|------|------|------|------|------|
|
|-----------|------|------|------|------|------|------|------|------|------|------|
|
||||||
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
||||||
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
|
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ |
|
||||||
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||||
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ |
|
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||||
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ |
|
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||||
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ |
|
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
|
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ |
|
||||||
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||||
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| CUMSUM | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ |
|
| CUMSUM | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ |
|
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||||
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
||||||
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||||
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ |
|
||||||
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ |
|
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| FILL | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ |
|
| FILL | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
|
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||||
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ |
|
||||||
| GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
| GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||||
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| GEGLU_QUICK | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
| GEGLU_QUICK | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
|
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||||
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| IM2COL | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| IM2COL | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| L2_NORM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| L2_NORM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ |
|
||||||
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | ✅ | ❌ |
|
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | ✅ | ❌ | ❌ |
|
||||||
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
||||||
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 |
|
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
|
||||||
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ |
|
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||||
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ |
|
| NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ | ❌ |
|
||||||
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
|
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
|
| OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
|
||||||
| PAD | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
| PAD | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||||
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||||
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ❌ |
|
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ❌ | ❌ |
|
||||||
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| RMS_NORM_MUL_ADD | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
| RMS_NORM_MUL_ADD | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
| ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ |
|
||||||
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
| SET | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
| SET | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
|
||||||
| SET_ROWS | ❌ | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
| SET_ROWS | ❌ | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||||
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ |
|
||||||
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ |
|
||||||
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
|
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
|
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ |
|
||||||
| SOLVE_TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | 🟡 | ❌ |
|
| SOLVE_TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ |
|
||||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ |
|
||||||
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ |
|
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||||
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
||||||
| SUM | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
| SUM | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ |
|
||||||
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||||
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ |
|
| SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ |
|
||||||
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| TOP_K | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | 🟡 | ❌ |
|
| TOP_K | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||||
| TRI | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ |
|
| TRI | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ |
|
||||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
|
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ |
|
||||||
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -20,6 +20,7 @@ else()
|
||||||
|
|
||||||
add_subdirectory(gguf-hash)
|
add_subdirectory(gguf-hash)
|
||||||
add_subdirectory(gguf)
|
add_subdirectory(gguf)
|
||||||
|
add_subdirectory(idle)
|
||||||
add_subdirectory(lookahead)
|
add_subdirectory(lookahead)
|
||||||
add_subdirectory(lookup)
|
add_subdirectory(lookup)
|
||||||
add_subdirectory(parallel)
|
add_subdirectory(parallel)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
# Build the `llama-idle` example executable from idle.cpp.
set(TARGET llama-idle)
add_executable(${TARGET} idle.cpp)

# Link against the llama and common libraries plus the platform thread library.
target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)

# Install the binary with the other runtime targets.
install(TARGETS ${TARGET} RUNTIME)
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
# llama.cpp/example/idle
|
||||||
|
|
||||||
|
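Measures the time of a single-token `llama_decode` call after idle pauses of increasing length, to check that the decode time stays constant after the GPU has been idle.

See https://github.com/ggml-org/llama.cpp/pull/17766 for details.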
|
||||||
|
|
@ -0,0 +1,110 @@
|
||||||
|
#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"

#include <chrono>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <string>
#include <thread>
#include <vector>
|
||||||
|
|
||||||
|
// Print a short usage example; argv[0] is used as the program name.
static void print_usage(int /*argc*/, char ** argv) {
    const char * prog = argv[0];

    // single call, adjacent literals are concatenated by the compiler
    printf("\nexample usage:\n"
           "\n %s -m model.gguf [-ngl n_gpu_layers]\n"
           "\n",
           prog);
}
|
||||||
|
|
||||||
|
int main(int argc, char ** argv) {
|
||||||
|
common_params params;
|
||||||
|
|
||||||
|
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON, print_usage)) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
common_init();
|
||||||
|
|
||||||
|
// init LLM
|
||||||
|
|
||||||
|
llama_backend_init();
|
||||||
|
llama_numa_init(params.numa);
|
||||||
|
|
||||||
|
// initialize the model
|
||||||
|
|
||||||
|
llama_model_params model_params = common_model_params_to_llama(params);
|
||||||
|
|
||||||
|
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params);
|
||||||
|
|
||||||
|
if (model == NULL) {
|
||||||
|
LOG_ERR("%s: error: unable to load model\n" , __func__);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||||
|
|
||||||
|
// we need just a dummy token to evaluate
|
||||||
|
std::vector<llama_token> prompt_tokens(1, llama_vocab_bos(vocab));
|
||||||
|
|
||||||
|
llama_context_params ctx_params = llama_context_default_params();
|
||||||
|
ctx_params.n_ctx = 512;
|
||||||
|
ctx_params.n_batch = 512;
|
||||||
|
ctx_params.no_perf = false;
|
||||||
|
|
||||||
|
llama_context * ctx = llama_init_from_model(model, ctx_params);
|
||||||
|
if (ctx == NULL) {
|
||||||
|
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
|
||||||
|
|
||||||
|
const int n_iters = 3;
|
||||||
|
|
||||||
|
// warm-up
|
||||||
|
llama_decode(ctx, batch);
|
||||||
|
llama_memory_clear(llama_get_memory(ctx), true);
|
||||||
|
llama_synchronize(ctx);
|
||||||
|
|
||||||
|
for (int64_t t_pause_ms = 0; t_pause_ms <= 4000; t_pause_ms += 800) {
|
||||||
|
double t_sum_us = 0.0;
|
||||||
|
double t_sum2_us = 0.0;
|
||||||
|
|
||||||
|
for (int i = 0; i < n_iters; i++) {
|
||||||
|
// this pause is important - it simulates "idle GPU"
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(t_pause_ms));
|
||||||
|
|
||||||
|
const int64_t t_start_us = llama_time_us();
|
||||||
|
|
||||||
|
// this should take constant time
|
||||||
|
llama_decode(ctx, batch);
|
||||||
|
llama_synchronize(ctx);
|
||||||
|
|
||||||
|
const int64_t t_end_us = llama_time_us();
|
||||||
|
|
||||||
|
const double t_cur_us = t_end_us - t_start_us;
|
||||||
|
|
||||||
|
#if 1
|
||||||
|
// print individual decode times
|
||||||
|
printf(" - decode time: %8.2f ms\n", t_cur_us / 1000);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
t_sum_us += t_cur_us;
|
||||||
|
t_sum2_us += t_cur_us * t_cur_us;
|
||||||
|
|
||||||
|
llama_memory_clear(llama_get_memory(ctx), true);
|
||||||
|
llama_synchronize(ctx); // just in case
|
||||||
|
}
|
||||||
|
|
||||||
|
const double t_avg_us = t_sum_us / n_iters;
|
||||||
|
const double t_dev_us = sqrt((t_sum2_us / (n_iters - 1)) - (t_avg_us * t_avg_us * n_iters) / (n_iters - 1));
|
||||||
|
|
||||||
|
printf("iters: %4d, pause: %5d ms, avg decode time: %8.2f +/- %4.2f ms\n", n_iters, (int) t_pause_ms, t_avg_us / 1000, t_dev_us / 1000);
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_free(ctx);
|
||||||
|
llama_model_free(model);
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
@ -241,6 +241,12 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
llama_batch_free(batch);
|
llama_batch_free(batch);
|
||||||
|
|
||||||
|
// this one is managed by common_init_result
|
||||||
|
//llama_free(ctx);
|
||||||
|
|
||||||
|
llama_free(ctx2);
|
||||||
|
llama_free(ctx3);
|
||||||
|
|
||||||
if (result0 != result2) {
|
if (result0 != result2) {
|
||||||
fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);
|
fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "ggml.h"
|
|
||||||
#include "ggml-backend.h"
|
#include "ggml-backend.h"
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
|
|
@ -8,7 +7,7 @@ extern "C" {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#define RPC_PROTO_MAJOR_VERSION 3
|
#define RPC_PROTO_MAJOR_VERSION 3
|
||||||
#define RPC_PROTO_MINOR_VERSION 5
|
#define RPC_PROTO_MINOR_VERSION 6
|
||||||
#define RPC_PROTO_PATCH_VERSION 0
|
#define RPC_PROTO_PATCH_VERSION 0
|
||||||
#define GGML_RPC_MAX_SERVERS 16
|
#define GGML_RPC_MAX_SERVERS 16
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2196,6 +2196,15 @@ extern "C" {
|
||||||
int p2,
|
int p2,
|
||||||
int p3);
|
int p3);
|
||||||
|
|
||||||
|
// pad each dimension with values on the other side of the torus (looping around)
|
||||||
|
GGML_API struct ggml_tensor * ggml_pad_circular(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
int p0,
|
||||||
|
int p1,
|
||||||
|
int p2,
|
||||||
|
int p3);
|
||||||
|
|
||||||
GGML_API struct ggml_tensor * ggml_pad_ext(
|
GGML_API struct ggml_tensor * ggml_pad_ext(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
|
|
@ -2209,6 +2218,19 @@ extern "C" {
|
||||||
int rp3
|
int rp3
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// pad each dimension with values on the other side of the torus (looping around)
|
||||||
|
GGML_API struct ggml_tensor * ggml_pad_ext_circular(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
int lp0,
|
||||||
|
int rp0,
|
||||||
|
int lp1,
|
||||||
|
int rp1,
|
||||||
|
int lp2,
|
||||||
|
int rp2,
|
||||||
|
int lp3,
|
||||||
|
int rp3);
|
||||||
|
|
||||||
// pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c]
|
// pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c]
|
||||||
GGML_API struct ggml_tensor * ggml_pad_reflect_1d(
|
GGML_API struct ggml_tensor * ggml_pad_reflect_1d(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
|
|
|
||||||
|
|
@ -534,8 +534,12 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
||||||
fs::path best_path;
|
fs::path best_path;
|
||||||
|
|
||||||
for (const auto & search_path : search_paths) {
|
for (const auto & search_path : search_paths) {
|
||||||
if (!fs::exists(search_path)) {
|
if (std::error_code ec; !fs::exists(search_path, ec)) {
|
||||||
|
if (ec) {
|
||||||
|
GGML_LOG_DEBUG("%s: posix_stat(%s) failure, error-message: %s\n", __func__, path_str(search_path).c_str(), ec.message().c_str());
|
||||||
|
} else {
|
||||||
GGML_LOG_DEBUG("%s: search path %s does not exist\n", __func__, path_str(search_path).c_str());
|
GGML_LOG_DEBUG("%s: search path %s does not exist\n", __func__, path_str(search_path).c_str());
|
||||||
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
|
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
|
||||||
|
|
@ -575,8 +579,12 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
||||||
for (const auto & search_path : search_paths) {
|
for (const auto & search_path : search_paths) {
|
||||||
fs::path filename = backend_filename_prefix().native() + name_path.native() + backend_filename_extension().native();
|
fs::path filename = backend_filename_prefix().native() + name_path.native() + backend_filename_extension().native();
|
||||||
fs::path path = search_path / filename;
|
fs::path path = search_path / filename;
|
||||||
if (fs::exists(path)) {
|
if (std::error_code ec; fs::exists(path, ec)) {
|
||||||
return get_reg().load_backend(path, silent);
|
return get_reg().load_backend(path, silent);
|
||||||
|
} else {
|
||||||
|
if (ec) {
|
||||||
|
GGML_LOG_DEBUG("%s: posix_stat(%s) failure, error-message: %s\n", __func__, path_str(path).c_str(), ec.message().c_str());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
|
||||||
|
|
@ -2551,6 +2551,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
||||||
case GGML_OP_ACC:
|
case GGML_OP_ACC:
|
||||||
case GGML_OP_GROUP_NORM:
|
case GGML_OP_GROUP_NORM:
|
||||||
case GGML_OP_PAD:
|
case GGML_OP_PAD:
|
||||||
|
// TODO: add circular padding support for cann, see https://github.com/ggml-org/llama.cpp/pull/16985
|
||||||
|
return ggml_get_op_params_i32(op, 8) == 0;
|
||||||
case GGML_OP_ARANGE:
|
case GGML_OP_ARANGE:
|
||||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||||
case GGML_OP_LEAKY_RELU:
|
case GGML_OP_LEAKY_RELU:
|
||||||
|
|
|
||||||
|
|
@ -6554,8 +6554,13 @@ static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params
|
||||||
ggml_compute_forward_mul_mat(params, &dst);
|
ggml_compute_forward_mul_mat(params, &dst);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Map a coordinate onto the range [0, size) with wrap-around (torus) semantics.
//   coord - coordinate to wrap, may be negative or >= size
//   size  - extent of the dimension, must be > 0
// The remainder operator keeps the sign of `coord`, so a negative remainder is
// shifted up by `size`. Unlike `(coord + size) % size`, this stays in range even
// when coord < -size (e.g. padding wider than the source dimension).
static inline int64_t ggml_wrap_around(int64_t coord, int64_t size) {
    const int64_t rem = coord % size;
    return rem < 0 ? rem + size : rem;
}
|
||||||
|
|
||||||
// ggml_compute_forward_conv_2d
|
// ggml_compute_forward_conv_2d
|
||||||
|
|
||||||
|
|
||||||
static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
|
static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
|
||||||
const ggml_tensor * kernel, // [KW, KH, IC, OC]
|
const ggml_tensor * kernel, // [KW, KH, IC, OC]
|
||||||
const ggml_tensor * src, // [W, H, C, N]
|
const ggml_tensor * src, // [W, H, C, N]
|
||||||
|
|
@ -7591,6 +7596,7 @@ void ggml_compute_forward_upscale(
|
||||||
|
|
||||||
// ggml_compute_forward_pad
|
// ggml_compute_forward_pad
|
||||||
|
|
||||||
|
template<bool circular_t>
|
||||||
static void ggml_compute_forward_pad_f32(
|
static void ggml_compute_forward_pad_f32(
|
||||||
const ggml_compute_params * params,
|
const ggml_compute_params * params,
|
||||||
ggml_tensor * dst) {
|
ggml_tensor * dst) {
|
||||||
|
|
@ -7615,13 +7621,29 @@ static void ggml_compute_forward_pad_f32(
|
||||||
const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
|
const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
|
||||||
const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
|
const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
|
||||||
|
|
||||||
|
|
||||||
// TODO: optimize
|
// TODO: optimize
|
||||||
|
|
||||||
for (int64_t i2 = 0; i2 < ne2; ++i2) {
|
for (int64_t i2 = 0; i2 < ne2; ++i2) {
|
||||||
for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
|
for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
|
||||||
for (int64_t i0 = 0; i0 < ne0; ++i0) {
|
for (int64_t i0 = 0; i0 < ne0; ++i0) {
|
||||||
for (int64_t i3 = 0; i3 < ne3; ++i3) {
|
for (int64_t i3 = 0; i3 < ne3; ++i3) {
|
||||||
|
// circular means wrap around on a torus, so x and y loop around
|
||||||
|
if constexpr (circular_t) {
|
||||||
|
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
|
||||||
|
const int64_t src_i0 = ggml_wrap_around(i0 - lp0, ne00);
|
||||||
|
const int64_t src_i1 = ggml_wrap_around(i1 - lp1, ne01);
|
||||||
|
const int64_t src_i2 = ggml_wrap_around(i2 - lp2, ne02);
|
||||||
|
const int64_t src_i3 = ggml_wrap_around(i3 - lp3, ne03);
|
||||||
|
|
||||||
|
const int64_t src_idx =
|
||||||
|
src_i3*nb03 +
|
||||||
|
src_i2*nb02 +
|
||||||
|
src_i1*nb01 +
|
||||||
|
src_i0*nb00;
|
||||||
|
|
||||||
|
const float * src_ptr = (const float *)((char *) src0->data + src_idx);
|
||||||
|
dst_ptr[dst_idx] = *src_ptr;
|
||||||
|
} else {
|
||||||
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
|
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
|
||||||
if ((i0 >= lp0 && i0 < ne0 - rp0) \
|
if ((i0 >= lp0 && i0 < ne0 - rp0) \
|
||||||
&& (i1 >= lp1 && i1 < ne1 - rp1) \
|
&& (i1 >= lp1 && i1 < ne1 - rp1) \
|
||||||
|
|
@ -7637,18 +7659,23 @@ static void ggml_compute_forward_pad_f32(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void ggml_compute_forward_pad(
|
void ggml_compute_forward_pad(
|
||||||
const ggml_compute_params * params,
|
const ggml_compute_params * params,
|
||||||
ggml_tensor * dst) {
|
ggml_tensor * dst) {
|
||||||
|
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
|
const bool circular = (bool) ggml_get_op_params_i32(dst, 8);
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
{
|
{
|
||||||
ggml_compute_forward_pad_f32(params, dst);
|
if (circular) {
|
||||||
|
ggml_compute_forward_pad_f32<true>(params, dst);
|
||||||
|
} else {
|
||||||
|
ggml_compute_forward_pad_f32<false>(params, dst);
|
||||||
|
}
|
||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -560,7 +560,7 @@ namespace ggml_cuda_mma {
|
||||||
xi[0] = xs[0];
|
xi[0] = xs[0];
|
||||||
xi[1] = xs[1];
|
xi[1] = xs[1];
|
||||||
#endif // defined(RDNA4)
|
#endif // defined(RDNA4)
|
||||||
}else if constexpr (I == 16 && J == 8) {
|
} else if constexpr (I == 16 && J == 8) {
|
||||||
int64_t * xi = (int64_t *) t.x;
|
int64_t * xi = (int64_t *) t.x;
|
||||||
#if defined(RDNA4)
|
#if defined(RDNA4)
|
||||||
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
|
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
|
||||||
|
|
@ -577,14 +577,13 @@ namespace ggml_cuda_mma {
|
||||||
const int64_t * xs1 = xs + 2;
|
const int64_t * xs1 = xs + 2;
|
||||||
xi[2] = xs1[0];
|
xi[2] = xs1[0];
|
||||||
xi[3] = xs1[1];
|
xi[3] = xs1[1];
|
||||||
|
#endif // defined(RDNA4)
|
||||||
}else{
|
} else {
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
}
|
}
|
||||||
#endif // defined(RDNA4)
|
|
||||||
#else
|
#else
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int l = 0; l < t.ne; ++l) {
|
for (int l = 0; l < t.ne; ++l) {
|
||||||
|
|
|
||||||
|
|
@ -160,9 +160,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
return ampere_mma_available(cc);
|
return ampere_mma_available(cc);
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc);
|
return volta_mma_available(cc) || turing_mma_available(cc) || (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc));
|
||||||
case GGML_TYPE_BF16:
|
case GGML_TYPE_BF16:
|
||||||
return ampere_mma_available(cc) || amd_wmma_available(cc);
|
return ampere_mma_available(cc) || (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc));
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,17 @@
|
||||||
#include "pad.cuh"
|
#include "pad.cuh"
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
// Map a coordinate onto the range [0, size) with wrap-around (torus) semantics.
// `size` must be > 0; `coord` may be negative or >= size.
__device__ __forceinline__ int64_t wrap_around(int64_t coord, int64_t size) {
    // the remainder keeps the sign of coord, so shift negative results up by size;
    // this also covers coord < -size, where (coord + size) % size would stay negative
    const int64_t rem = coord % size;
    return rem < 0 ? rem + size : rem;
}
|
||||||
|
|
||||||
static __global__ void pad_f32(const float * src, float * dst,
|
static __global__ void pad_f32(const float * src, float * dst,
|
||||||
const int lp0, const int rp0, const int lp1, const int rp1,
|
const int lp0, const int rp0, const int lp1, const int rp1,
|
||||||
const int lp2, const int rp2, const int lp3, const int rp3,
|
const int lp2, const int rp2, const int lp3, const int rp3,
|
||||||
const int ne0, const int ne1, const int ne2, const int ne3) {
|
const int ne0, const int ne1, const int ne2, const int ne3,
|
||||||
|
const bool circular) {
|
||||||
// blockIdx.z: i3*ne2+i2
|
// blockIdx.z: i3*ne2+i2
|
||||||
// blockIdx.y: i1
|
// blockIdx.y: i1
|
||||||
// blockIDx.x: i0 / CUDA_PAD_BLOCK_SIZE
|
// blockIDx.x: i0 / CUDA_PAD_BLOCK_SIZE
|
||||||
|
|
@ -12,15 +20,15 @@ static __global__ void pad_f32(const float * src, float * dst,
|
||||||
int i1 = blockIdx.y;
|
int i1 = blockIdx.y;
|
||||||
int i2 = blockIdx.z % ne2;
|
int i2 = blockIdx.z % ne2;
|
||||||
int i3 = blockIdx.z / ne2;
|
int i3 = blockIdx.z / ne2;
|
||||||
|
|
||||||
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
|
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// operation
|
const int64_t dst_idx = i3 * (ne0 * ne1 * ne2) + i2 * (ne0 * ne1) + i1 * ne0 + i0;
|
||||||
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
|
|
||||||
if ((i0 >= lp0 && i0 < ne0 - rp0) &&
|
if (!circular) {
|
||||||
(i1 >= lp1 && i1 < ne1 - rp1) &&
|
if ((i0 >= lp0 && i0 < ne0 - rp0) && (i1 >= lp1 && i1 < ne1 - rp1) && (i2 >= lp2 && i2 < ne2 - rp2) &&
|
||||||
(i2 >= lp2 && i2 < ne2 - rp2) &&
|
|
||||||
(i3 >= lp3 && i3 < ne3 - rp3)) {
|
(i3 >= lp3 && i3 < ne3 - rp3)) {
|
||||||
const int64_t i00 = i0 - lp0;
|
const int64_t i00 = i0 - lp0;
|
||||||
const int64_t i01 = i1 - lp1;
|
const int64_t i01 = i1 - lp1;
|
||||||
|
|
@ -30,43 +38,66 @@ static __global__ void pad_f32(const float * src, float * dst,
|
||||||
const int64_t ne01 = ne1 - lp1 - rp1;
|
const int64_t ne01 = ne1 - lp1 - rp1;
|
||||||
const int64_t ne00 = ne0 - lp0 - rp0;
|
const int64_t ne00 = ne0 - lp0 - rp0;
|
||||||
|
|
||||||
const int64_t src_idx = i03*(ne00*ne01*ne02) + i02*(ne00*ne01) + i01*ne00 + i00;
|
const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00;
|
||||||
|
|
||||||
dst[dst_idx] = src[src_idx];
|
dst[dst_idx] = src[src_idx];
|
||||||
} else {
|
} else {
|
||||||
dst[dst_idx] = 0.0f;
|
dst[dst_idx] = 0.0f;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
// circular means on a torus, so x and y wrap around
|
||||||
|
else {
|
||||||
|
const int64_t ne00 = ne0 - lp0 - rp0;
|
||||||
|
const int64_t ne01 = ne1 - lp1 - rp1;
|
||||||
|
const int64_t ne02 = ne2 - lp2 - rp2;
|
||||||
|
const int64_t ne03 = ne3 - lp3 - rp3;
|
||||||
|
|
||||||
|
const int64_t i00 = wrap_around(i0 - lp0, ne00);
|
||||||
|
const int64_t i01 = wrap_around(i1 - lp1, ne01);
|
||||||
|
const int64_t i02 = wrap_around(i2 - lp2, ne02);
|
||||||
|
const int64_t i03 = wrap_around(i3 - lp3, ne03);
|
||||||
|
|
||||||
|
const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00;
|
||||||
|
|
||||||
|
dst[dst_idx] = src[src_idx];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
static void pad_f32_cuda(const float * src, float * dst,
|
static void pad_f32_cuda(const float * src, float * dst,
|
||||||
const int lp0, const int rp0, const int lp1, const int rp1,
|
const int lp0, const int rp0, const int lp1, const int rp1,
|
||||||
const int lp2, const int rp2, const int lp3, const int rp3,
|
const int lp2, const int rp2, const int lp3, const int rp3,
|
||||||
const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
|
const int ne0, const int ne1, const int ne2, const int ne3,
|
||||||
|
const bool circular, cudaStream_t stream) {
|
||||||
int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
|
int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
|
||||||
dim3 gridDim(num_blocks, ne1, ne2*ne3);
|
dim3 gridDim(num_blocks, ne1, ne2 * ne3);
|
||||||
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1, ne2, ne3);
|
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst,
|
||||||
|
lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
|
||||||
|
ne0, ne1, ne2, ne3, circular);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
const float * src0_d = (const float *)src0->data;
|
const float * src0_d = (const float *) src0->data;
|
||||||
float * dst_d = (float *)dst->data;
|
float * dst_d = (float *) dst->data;
|
||||||
cudaStream_t stream = ctx.stream();
|
cudaStream_t stream = ctx.stream();
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
|
|
||||||
const int32_t lp0 = ((const int32_t*)(dst->op_params))[0];
|
const int32_t lp0 = ((const int32_t *) (dst->op_params))[0];
|
||||||
const int32_t rp0 = ((const int32_t*)(dst->op_params))[1];
|
const int32_t rp0 = ((const int32_t *) (dst->op_params))[1];
|
||||||
const int32_t lp1 = ((const int32_t*)(dst->op_params))[2];
|
const int32_t lp1 = ((const int32_t *) (dst->op_params))[2];
|
||||||
const int32_t rp1 = ((const int32_t*)(dst->op_params))[3];
|
const int32_t rp1 = ((const int32_t *) (dst->op_params))[3];
|
||||||
const int32_t lp2 = ((const int32_t*)(dst->op_params))[4];
|
const int32_t lp2 = ((const int32_t *) (dst->op_params))[4];
|
||||||
const int32_t rp2 = ((const int32_t*)(dst->op_params))[5];
|
const int32_t rp2 = ((const int32_t *) (dst->op_params))[5];
|
||||||
const int32_t lp3 = ((const int32_t*)(dst->op_params))[6];
|
const int32_t lp3 = ((const int32_t *) (dst->op_params))[6];
|
||||||
const int32_t rp3 = ((const int32_t*)(dst->op_params))[7];
|
const int32_t rp3 = ((const int32_t *) (dst->op_params))[7];
|
||||||
|
const int32_t circular = ((const int32_t *) (dst->op_params))[8];
|
||||||
|
|
||||||
pad_f32_cuda(src0_d, dst_d,
|
pad_f32_cuda(src0_d, dst_d,
|
||||||
lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
|
lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
|
||||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
|
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||||
|
(bool) circular, stream);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -24,9 +24,6 @@ struct ggml_metal_command_buffer {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ggml_metal {
|
struct ggml_metal {
|
||||||
id<MTLDevice> device;
|
|
||||||
id<MTLCommandQueue> queue; // currently a pointer to the device queue, but might become separate queue [TAG_QUEUE_PER_BACKEND]
|
|
||||||
|
|
||||||
ggml_metal_device_t dev;
|
ggml_metal_device_t dev;
|
||||||
ggml_metal_library_t lib;
|
ggml_metal_library_t lib;
|
||||||
|
|
||||||
|
|
@ -91,15 +88,15 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
|
||||||
// init context
|
// init context
|
||||||
ggml_metal_t res = calloc(1, sizeof(struct ggml_metal));
|
ggml_metal_t res = calloc(1, sizeof(struct ggml_metal));
|
||||||
|
|
||||||
res->device = ggml_metal_device_get_obj(dev);
|
id<MTLDevice> device = ggml_metal_device_get_obj(dev);
|
||||||
|
|
||||||
GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[res->device name] UTF8String]);
|
GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
|
||||||
|
|
||||||
// TODO: would it be better to have one queue for the backend and one queue for the device?
|
// TODO: would it be better to have one queue for the backend and one queue for the device?
|
||||||
// the graph encoders and async ops would use the backend queue while the sync ops would use the device queue?
|
// the graph encoders and async ops would use the backend queue while the sync ops would use the device queue?
|
||||||
//res->queue = [device newCommandQueue]; [TAG_QUEUE_PER_BACKEND]
|
//res->queue = [device newCommandQueue]; [TAG_QUEUE_PER_BACKEND]
|
||||||
res->queue = ggml_metal_device_get_queue(dev);
|
id<MTLCommandQueue> queue = ggml_metal_device_get_queue(dev);
|
||||||
if (res->queue == nil) {
|
if (queue == nil) {
|
||||||
GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
|
GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
@ -274,7 +271,8 @@ static struct ggml_metal_buffer_id ggml_metal_get_buffer_id(const struct ggml_te
|
||||||
void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||||
@autoreleasepool {
|
@autoreleasepool {
|
||||||
// wrap the source data into a Metal buffer
|
// wrap the source data into a Metal buffer
|
||||||
id<MTLBuffer> buf_src = [ctx->device newBufferWithBytes:data
|
id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);
|
||||||
|
id<MTLBuffer> buf_src = [device newBufferWithBytes:data
|
||||||
length:size
|
length:size
|
||||||
options:MTLResourceStorageModeShared];
|
options:MTLResourceStorageModeShared];
|
||||||
|
|
||||||
|
|
@ -289,7 +287,8 @@ void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor,
|
||||||
|
|
||||||
// queue the copy operation into the queue of the Metal context
|
// queue the copy operation into the queue of the Metal context
|
||||||
// this will be queued at the end, after any currently ongoing GPU operations
|
// this will be queued at the end, after any currently ongoing GPU operations
|
||||||
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
|
id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev);
|
||||||
|
id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];
|
||||||
id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
|
id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
|
||||||
|
|
||||||
[encoder copyFromBuffer:buf_src
|
[encoder copyFromBuffer:buf_src
|
||||||
|
|
@ -315,7 +314,8 @@ void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor,
|
||||||
|
|
||||||
void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
||||||
@autoreleasepool {
|
@autoreleasepool {
|
||||||
id<MTLBuffer> buf_dst = [ctx->device newBufferWithBytesNoCopy:data
|
id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);
|
||||||
|
id<MTLBuffer> buf_dst = [device newBufferWithBytesNoCopy:data
|
||||||
length:size
|
length:size
|
||||||
options:MTLResourceStorageModeShared
|
options:MTLResourceStorageModeShared
|
||||||
deallocator:nil];
|
deallocator:nil];
|
||||||
|
|
@ -331,7 +331,8 @@ void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * te
|
||||||
|
|
||||||
// queue the copy operation into the queue of the Metal context
|
// queue the copy operation into the queue of the Metal context
|
||||||
// this will be queued at the end, after any currently ongoing GPU operations
|
// this will be queued at the end, after any currently ongoing GPU operations
|
||||||
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
|
id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev);
|
||||||
|
id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];
|
||||||
id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
|
id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
|
||||||
|
|
||||||
[encoder copyFromBuffer:bid_src.metal
|
[encoder copyFromBuffer:bid_src.metal
|
||||||
|
|
@ -362,6 +363,9 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *
|
||||||
// number of threads in addition to the main thread
|
// number of threads in addition to the main thread
|
||||||
const int n_cb = ctx->n_cb;
|
const int n_cb = ctx->n_cb;
|
||||||
|
|
||||||
|
// keep the memory wired
|
||||||
|
ggml_metal_device_rsets_keep_alive(ctx->dev);
|
||||||
|
|
||||||
// submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them
|
// submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them
|
||||||
// the first n_nodes_0 are encoded and submitted for processing directly by the calling thread
|
// the first n_nodes_0 are encoded and submitted for processing directly by the calling thread
|
||||||
// while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes
|
// while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes
|
||||||
|
|
@ -389,7 +393,8 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *
|
||||||
|
|
||||||
if (!ctx->capture_started) {
|
if (!ctx->capture_started) {
|
||||||
// create capture scope
|
// create capture scope
|
||||||
ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->device];
|
id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);
|
||||||
|
ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:device];
|
||||||
|
|
||||||
MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
|
MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
|
||||||
descriptor.captureObject = ctx->capture_scope;
|
descriptor.captureObject = ctx->capture_scope;
|
||||||
|
|
@ -406,10 +411,13 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// short-hand
|
||||||
|
id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev);
|
||||||
|
|
||||||
// the main thread commits the first few commands immediately
|
// the main thread commits the first few commands immediately
|
||||||
// cmd_buf[n_cb]
|
// cmd_buf[n_cb]
|
||||||
{
|
{
|
||||||
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
|
id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];
|
||||||
[cmd_buf retain];
|
[cmd_buf retain];
|
||||||
|
|
||||||
if (ctx->cmd_bufs[n_cb].obj) {
|
if (ctx->cmd_bufs[n_cb].obj) {
|
||||||
|
|
@ -428,7 +436,7 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *
|
||||||
// prepare the rest of the command buffers asynchronously (optional)
|
// prepare the rest of the command buffers asynchronously (optional)
|
||||||
// cmd_buf[0.. n_cb)
|
// cmd_buf[0.. n_cb)
|
||||||
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
|
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
|
||||||
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
|
id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];
|
||||||
[cmd_buf retain];
|
[cmd_buf retain];
|
||||||
|
|
||||||
if (ctx->cmd_bufs[cb_idx].obj) {
|
if (ctx->cmd_bufs[cb_idx].obj) {
|
||||||
|
|
@ -589,9 +597,11 @@ void ggml_metal_set_abort_callback(ggml_metal_t ctx, ggml_abort_callback abort_c
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ggml_metal_supports_family(ggml_metal_t ctx, int family) {
|
bool ggml_metal_supports_family(ggml_metal_t ctx, int family) {
|
||||||
GGML_ASSERT(ctx->device != nil);
|
GGML_ASSERT(ctx->dev != nil);
|
||||||
|
|
||||||
return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
|
id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);
|
||||||
|
|
||||||
|
return [device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_metal_capture_next_compute(ggml_metal_t ctx) {
|
void ggml_metal_capture_next_compute(ggml_metal_t ctx) {
|
||||||
|
|
|
||||||
|
|
@ -186,6 +186,16 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_att
|
||||||
int32_t dv,
|
int32_t dv,
|
||||||
int32_t nwg);
|
int32_t nwg);
|
||||||
|
|
||||||
|
// MTLResidencySet wrapper
|
||||||
|
|
||||||
|
typedef void * ggml_metal_rset_t;
|
||||||
|
|
||||||
|
// a collection of residency sets (non-owning)
|
||||||
|
typedef struct ggml_metal_rsets * ggml_metal_rsets_t;
|
||||||
|
|
||||||
|
ggml_metal_rsets_t ggml_metal_rsets_init(void);
|
||||||
|
void ggml_metal_rsets_free(ggml_metal_rsets_t rsets);
|
||||||
|
|
||||||
//
|
//
|
||||||
// device
|
// device
|
||||||
//
|
//
|
||||||
|
|
@ -219,6 +229,11 @@ void * ggml_metal_device_get_queue(ggml_metal_device_t dev); // id<MTLCommandQue
|
||||||
|
|
||||||
ggml_metal_library_t ggml_metal_device_get_library(ggml_metal_device_t dev);
|
ggml_metal_library_t ggml_metal_device_get_library(ggml_metal_device_t dev);
|
||||||
|
|
||||||
|
void ggml_metal_device_rsets_add(ggml_metal_device_t dev, ggml_metal_rset_t rset);
|
||||||
|
void ggml_metal_device_rsets_rm (ggml_metal_device_t dev, ggml_metal_rset_t rset);
|
||||||
|
|
||||||
|
void ggml_metal_device_rsets_keep_alive(ggml_metal_device_t dev);
|
||||||
|
|
||||||
void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total);
|
void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total);
|
||||||
bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_tensor * op);
|
bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_tensor * op);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
#import "ggml-metal-device.h"
|
#import "ggml-metal-device.h"
|
||||||
|
|
||||||
#import "ggml-impl.h"
|
#import "ggml-impl.h"
|
||||||
#import "ggml-threading.h"
|
|
||||||
|
|
||||||
#include <Foundation/Foundation.h>
|
#include <Foundation/Foundation.h>
|
||||||
|
|
||||||
|
|
@ -519,11 +518,106 @@ struct ggml_metal_device {
|
||||||
// ref: https://github.com/ggml-org/llama.cpp/pull/15906
|
// ref: https://github.com/ggml-org/llama.cpp/pull/15906
|
||||||
id<MTLCommandQueue> mtl_queue;
|
id<MTLCommandQueue> mtl_queue;
|
||||||
|
|
||||||
|
ggml_metal_rsets_t rsets;
|
||||||
|
|
||||||
ggml_metal_library_t library;
|
ggml_metal_library_t library;
|
||||||
|
|
||||||
struct ggml_metal_device_props props;
|
struct ggml_metal_device_props props;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
//
|
||||||
|
// MTLResidencySet wrapper
|
||||||
|
//
|
||||||
|
|
||||||
|
struct ggml_metal_rsets {
|
||||||
|
NSLock * lock;
|
||||||
|
|
||||||
|
NSMutableArray * data;
|
||||||
|
|
||||||
|
// number of seconds to keep the residency sets wired after the last graph computation
|
||||||
|
// keeping the sets wired for that amount of time avoids them being collected by the OS
|
||||||
|
int keep_alive_s;
|
||||||
|
|
||||||
|
// background heartbeat thread to keep the residency sets alive
|
||||||
|
atomic_bool d_stop;
|
||||||
|
atomic_int d_loop;
|
||||||
|
|
||||||
|
dispatch_group_t d_group;
|
||||||
|
};
|
||||||
|
|
||||||
|
ggml_metal_rsets_t ggml_metal_rsets_init(void) {
|
||||||
|
ggml_metal_rsets_t res = calloc(1, sizeof(struct ggml_metal_rsets));
|
||||||
|
|
||||||
|
res->lock = [[NSLock alloc] init];
|
||||||
|
res->data = [[NSMutableArray alloc] init];
|
||||||
|
|
||||||
|
// by default keep the memory wired for 3 minutes
|
||||||
|
res->keep_alive_s = 3*60;
|
||||||
|
|
||||||
|
const char * GGML_METAL_RESIDENCY_KEEP_ALIVE_S = getenv("GGML_METAL_RESIDENCY_KEEP_ALIVE_S");
|
||||||
|
if (GGML_METAL_RESIDENCY_KEEP_ALIVE_S) {
|
||||||
|
res->keep_alive_s = atoi(GGML_METAL_RESIDENCY_KEEP_ALIVE_S);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (res->keep_alive_s <= 0) {
|
||||||
|
res->keep_alive_s = 3*60;
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_LOG_INFO("%s: creating a residency set collection (keep_alive = %d s)\n", __func__, res->keep_alive_s);
|
||||||
|
|
||||||
|
atomic_store_explicit(&res->d_stop, false, memory_order_relaxed);
|
||||||
|
atomic_store_explicit(&res->d_loop, 2*res->keep_alive_s, memory_order_relaxed);
|
||||||
|
|
||||||
|
res->d_group = dispatch_group_create();
|
||||||
|
|
||||||
|
// start a background thread that periodically requests residency for all the currently active sets in the collection
|
||||||
|
// the requests stop after a certain amount of time (keep_alive_s) of inactivity
|
||||||
|
dispatch_queue_t d_queue = dispatch_get_global_queue(QOS_CLASS_DEFAULT, 0);
|
||||||
|
dispatch_group_async(res->d_group, d_queue, ^{
|
||||||
|
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
|
||||||
|
if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) {
|
||||||
|
while (!atomic_load_explicit(&res->d_stop, memory_order_relaxed)) {
|
||||||
|
if (atomic_load_explicit(&res->d_loop, memory_order_relaxed) > 0) {
|
||||||
|
[res->lock lock];
|
||||||
|
|
||||||
|
for (int i = 0; i < (int) res->data.count; ++i) {
|
||||||
|
[res->data[i] requestResidency];
|
||||||
|
}
|
||||||
|
|
||||||
|
atomic_fetch_sub_explicit(&res->d_loop, 1, memory_order_relaxed);
|
||||||
|
|
||||||
|
[res->lock unlock];
|
||||||
|
}
|
||||||
|
|
||||||
|
// half a second
|
||||||
|
usleep(500 * 1000);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
});
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_metal_rsets_free(ggml_metal_rsets_t rsets) {
|
||||||
|
if (rsets == NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// note: if you hit this assert, most likely you haven't deallocated all Metal resources before exiting
|
||||||
|
GGML_ASSERT([rsets->data count] == 0);
|
||||||
|
|
||||||
|
atomic_store_explicit(&rsets->d_stop, true, memory_order_relaxed);
|
||||||
|
|
||||||
|
dispatch_group_wait(rsets->d_group, DISPATCH_TIME_FOREVER);
|
||||||
|
dispatch_release(rsets->d_group);
|
||||||
|
|
||||||
|
[rsets->data release];
|
||||||
|
[rsets->lock release];
|
||||||
|
|
||||||
|
free(rsets);
|
||||||
|
}
|
||||||
|
|
||||||
ggml_metal_device_t ggml_metal_device_init(void) {
|
ggml_metal_device_t ggml_metal_device_init(void) {
|
||||||
ggml_metal_device_t dev = calloc(1, sizeof(struct ggml_metal_device));
|
ggml_metal_device_t dev = calloc(1, sizeof(struct ggml_metal_device));
|
||||||
|
|
||||||
|
|
@ -692,7 +786,11 @@ ggml_metal_device_t ggml_metal_device_init(void) {
|
||||||
GGML_LOG_ERROR("%s: error: failed to create library\n", __func__);
|
GGML_LOG_ERROR("%s: error: failed to create library\n", __func__);
|
||||||
}
|
}
|
||||||
|
|
||||||
// --------------------------------------------------
|
if (dev->props.use_residency_sets) {
|
||||||
|
dev->rsets = ggml_metal_rsets_init();
|
||||||
|
} else {
|
||||||
|
dev->rsets = nil;
|
||||||
|
}
|
||||||
|
|
||||||
// print MTL GPU family:
|
// print MTL GPU family:
|
||||||
GGML_LOG_INFO("%s: GPU name: %s\n", __func__, dev->props.name);
|
GGML_LOG_INFO("%s: GPU name: %s\n", __func__, dev->props.name);
|
||||||
|
|
@ -745,6 +843,8 @@ ggml_metal_device_t ggml_metal_device_init(void) {
|
||||||
void ggml_metal_device_free(ggml_metal_device_t dev) {
|
void ggml_metal_device_free(ggml_metal_device_t dev) {
|
||||||
assert(dev != NULL);
|
assert(dev != NULL);
|
||||||
|
|
||||||
|
ggml_metal_rsets_free(dev->rsets);
|
||||||
|
|
||||||
ggml_metal_library_free(dev->library);
|
ggml_metal_library_free(dev->library);
|
||||||
dev->library = NULL;
|
dev->library = NULL;
|
||||||
|
|
||||||
|
|
@ -773,6 +873,42 @@ ggml_metal_library_t ggml_metal_device_get_library(ggml_metal_device_t dev) {
|
||||||
return dev->library;
|
return dev->library;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_metal_device_rsets_add(ggml_metal_device_t dev, ggml_metal_rset_t rset) {
|
||||||
|
if (rset == nil) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_ASSERT(dev->rsets);
|
||||||
|
|
||||||
|
[dev->rsets->lock lock];
|
||||||
|
|
||||||
|
[dev->rsets->data addObject:rset];
|
||||||
|
|
||||||
|
[dev->rsets->lock unlock];
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_metal_device_rsets_rm(ggml_metal_device_t dev, ggml_metal_rset_t rset) {
|
||||||
|
if (rset == nil) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_ASSERT(dev->rsets);
|
||||||
|
|
||||||
|
[dev->rsets->lock lock];
|
||||||
|
|
||||||
|
[dev->rsets->data removeObject:rset];
|
||||||
|
|
||||||
|
[dev->rsets->lock unlock];
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_metal_device_rsets_keep_alive(ggml_metal_device_t dev) {
|
||||||
|
if (dev->rsets == NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
atomic_store_explicit(&dev->rsets->d_loop, 2*dev->rsets->keep_alive_s, memory_order_relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total) {
|
void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total) {
|
||||||
if (@available(macOS 10.12, iOS 16.0, *)) {
|
if (@available(macOS 10.12, iOS 16.0, *)) {
|
||||||
*total = dev->mtl_device.recommendedMaxWorkingSetSize;
|
*total = dev->mtl_device.recommendedMaxWorkingSetSize;
|
||||||
|
|
@ -901,6 +1037,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||||
case GGML_OP_POOL_2D:
|
case GGML_OP_POOL_2D:
|
||||||
return op->src[0]->type == GGML_TYPE_F32;
|
return op->src[0]->type == GGML_TYPE_F32;
|
||||||
case GGML_OP_PAD:
|
case GGML_OP_PAD:
|
||||||
|
// TODO: add circular padding support for metal, see https://github.com/ggml-org/llama.cpp/pull/16985
|
||||||
|
if (ggml_get_op_params_i32(op, 8) != 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
|
return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
|
||||||
(ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
|
(ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
|
||||||
case GGML_OP_PAD_REFLECT_1D:
|
case GGML_OP_PAD_REFLECT_1D:
|
||||||
|
|
@ -1066,9 +1207,8 @@ struct ggml_metal_buffer {
|
||||||
// note: cannot explicitly use "id<MTLResidencySet>" here because it is not available on certain OSes
|
// note: cannot explicitly use "id<MTLResidencySet>" here because it is not available on certain OSes
|
||||||
id rset;
|
id rset;
|
||||||
|
|
||||||
// pointers to global device objects
|
// pointers to global device
|
||||||
id<MTLDevice> device;
|
ggml_metal_device_t dev;
|
||||||
id<MTLCommandQueue> queue;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
static void ggml_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
|
static void ggml_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
|
||||||
|
|
@ -1111,7 +1251,7 @@ static bool ggml_metal_buffer_rset_init(ggml_metal_buffer_t buf) {
|
||||||
desc.initialCapacity = buf->n_buffers;
|
desc.initialCapacity = buf->n_buffers;
|
||||||
|
|
||||||
NSError * error;
|
NSError * error;
|
||||||
buf->rset = [buf->device newResidencySetWithDescriptor:desc error:&error];
|
buf->rset = [buf->dev->mtl_device newResidencySetWithDescriptor:desc error:&error];
|
||||||
if (error) {
|
if (error) {
|
||||||
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
||||||
[desc release];
|
[desc release];
|
||||||
|
|
@ -1172,6 +1312,8 @@ static void * ggml_metal_host_malloc(size_t n) {
|
||||||
ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, bool shared) {
|
ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, bool shared) {
|
||||||
ggml_metal_buffer_t res = calloc(1, sizeof(struct ggml_metal_buffer));
|
ggml_metal_buffer_t res = calloc(1, sizeof(struct ggml_metal_buffer));
|
||||||
|
|
||||||
|
res->dev = dev;
|
||||||
|
|
||||||
const size_t size_page = sysconf(_SC_PAGESIZE);
|
const size_t size_page = sysconf(_SC_PAGESIZE);
|
||||||
|
|
||||||
size_t size_aligned = size;
|
size_t size_aligned = size;
|
||||||
|
|
@ -1196,9 +1338,6 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size,
|
||||||
|
|
||||||
res->owned = true;
|
res->owned = true;
|
||||||
|
|
||||||
res->device = ggml_metal_device_get_obj(dev);
|
|
||||||
res->queue = ggml_metal_device_get_queue(dev);
|
|
||||||
|
|
||||||
res->n_buffers = 1;
|
res->n_buffers = 1;
|
||||||
|
|
||||||
if (res->all_data != NULL) {
|
if (res->all_data != NULL) {
|
||||||
|
|
@ -1207,12 +1346,12 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size,
|
||||||
|
|
||||||
if (size_aligned > 0) {
|
if (size_aligned > 0) {
|
||||||
if (props_dev->use_shared_buffers && shared) {
|
if (props_dev->use_shared_buffers && shared) {
|
||||||
res->buffers[0].metal = [res->device newBufferWithBytesNoCopy:res->all_data
|
res->buffers[0].metal = [res->dev->mtl_device newBufferWithBytesNoCopy:res->all_data
|
||||||
length:size_aligned
|
length:size_aligned
|
||||||
options:MTLResourceStorageModeShared
|
options:MTLResourceStorageModeShared
|
||||||
deallocator:nil];
|
deallocator:nil];
|
||||||
} else {
|
} else {
|
||||||
res->buffers[0].metal = [res->device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate];
|
res->buffers[0].metal = [res->dev->mtl_device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1233,6 +1372,8 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size,
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ggml_metal_device_rsets_add(dev, res->rset);
|
||||||
|
|
||||||
//ggml_metal_log_allocated_size(device, size_aligned);
|
//ggml_metal_log_allocated_size(device, size_aligned);
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
|
|
@ -1241,6 +1382,8 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size,
|
||||||
ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, size_t size, size_t max_tensor_size) {
|
ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, size_t size, size_t max_tensor_size) {
|
||||||
ggml_metal_buffer_t res = calloc(1, sizeof(struct ggml_metal_buffer));
|
ggml_metal_buffer_t res = calloc(1, sizeof(struct ggml_metal_buffer));
|
||||||
|
|
||||||
|
res->dev = dev;
|
||||||
|
|
||||||
res->all_data = ptr;
|
res->all_data = ptr;
|
||||||
res->all_size = size;
|
res->all_size = size;
|
||||||
|
|
||||||
|
|
@ -1263,9 +1406,6 @@ ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, s
|
||||||
size_aligned += (size_page - (size_aligned % size_page));
|
size_aligned += (size_page - (size_aligned % size_page));
|
||||||
}
|
}
|
||||||
|
|
||||||
res->device = ggml_metal_device_get_obj(dev);
|
|
||||||
res->queue = ggml_metal_device_get_queue(dev);
|
|
||||||
|
|
||||||
const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev);
|
const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev);
|
||||||
|
|
||||||
// the buffer fits into the max buffer size allowed by the device
|
// the buffer fits into the max buffer size allowed by the device
|
||||||
|
|
@ -1275,7 +1415,7 @@ ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, s
|
||||||
res->buffers[res->n_buffers].metal = nil;
|
res->buffers[res->n_buffers].metal = nil;
|
||||||
|
|
||||||
if (size_aligned > 0) {
|
if (size_aligned > 0) {
|
||||||
res->buffers[res->n_buffers].metal = [res->device newBufferWithBytesNoCopy:ptr length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
res->buffers[res->n_buffers].metal = [res->dev->mtl_device newBufferWithBytesNoCopy:ptr length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
||||||
|
|
||||||
if (res->buffers[res->n_buffers].metal == nil) {
|
if (res->buffers[res->n_buffers].metal == nil) {
|
||||||
GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
|
GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
|
||||||
|
|
@ -1284,7 +1424,7 @@ ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, s
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_metal_log_allocated_size(res->device, size_aligned);
|
ggml_metal_log_allocated_size(res->dev->mtl_device, size_aligned);
|
||||||
|
|
||||||
++res->n_buffers;
|
++res->n_buffers;
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -1302,7 +1442,7 @@ ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, s
|
||||||
res->buffers[res->n_buffers].metal = nil;
|
res->buffers[res->n_buffers].metal = nil;
|
||||||
|
|
||||||
if (size_step_aligned > 0) {
|
if (size_step_aligned > 0) {
|
||||||
res->buffers[res->n_buffers].metal = [res->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) ptr + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
res->buffers[res->n_buffers].metal = [res->dev->mtl_device newBufferWithBytesNoCopy:(void *) ((uint8_t *) ptr + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
||||||
|
|
||||||
if (res->buffers[res->n_buffers].metal == nil) {
|
if (res->buffers[res->n_buffers].metal == nil) {
|
||||||
GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
|
GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
|
||||||
|
|
@ -1311,7 +1451,7 @@ ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, s
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_metal_log_allocated_size(res->device, size_step_aligned);
|
ggml_metal_log_allocated_size(res->dev->mtl_device, size_step_aligned);
|
||||||
|
|
||||||
if (i + size_step < size) {
|
if (i + size_step < size) {
|
||||||
GGML_LOG_INFO("\n");
|
GGML_LOG_INFO("\n");
|
||||||
|
|
@ -1329,10 +1469,14 @@ ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, s
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ggml_metal_device_rsets_add(dev, res->rset);
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_metal_buffer_free(ggml_metal_buffer_t buf) {
|
void ggml_metal_buffer_free(ggml_metal_buffer_t buf) {
|
||||||
|
ggml_metal_device_rsets_rm(buf->dev, buf->rset);
|
||||||
|
|
||||||
for (int i = 0; i < buf->n_buffers; i++) {
|
for (int i = 0; i < buf->n_buffers; i++) {
|
||||||
[buf->buffers[i].metal release];
|
[buf->buffers[i].metal release];
|
||||||
}
|
}
|
||||||
|
|
@ -1369,8 +1513,7 @@ void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor
|
||||||
struct ggml_metal_buffer_id bid_dst = ggml_metal_buffer_get_id(buf, tensor);
|
struct ggml_metal_buffer_id bid_dst = ggml_metal_buffer_get_id(buf, tensor);
|
||||||
bid_dst.offs += offset;
|
bid_dst.offs += offset;
|
||||||
|
|
||||||
id<MTLCommandQueue> queue = buf->queue;
|
id<MTLCommandBuffer> cmd_buf = [buf->dev->mtl_queue commandBufferWithUnretainedReferences];
|
||||||
id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];
|
|
||||||
|
|
||||||
{
|
{
|
||||||
id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
|
id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
|
||||||
|
|
@ -1396,7 +1539,7 @@ void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor *
|
||||||
@autoreleasepool {
|
@autoreleasepool {
|
||||||
// src
|
// src
|
||||||
void * data_ptr = (void *)(uintptr_t) data; // "const cast" the src data
|
void * data_ptr = (void *)(uintptr_t) data; // "const cast" the src data
|
||||||
id<MTLBuffer> buf_src = [buf->device newBufferWithBytesNoCopy:data_ptr
|
id<MTLBuffer> buf_src = [buf->dev->mtl_device newBufferWithBytesNoCopy:data_ptr
|
||||||
length:size
|
length:size
|
||||||
options:MTLResourceStorageModeShared
|
options:MTLResourceStorageModeShared
|
||||||
deallocator:nil];
|
deallocator:nil];
|
||||||
|
|
@ -1411,8 +1554,7 @@ void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor *
|
||||||
// this is an alternative to waitUntilCompleted, which should be faster, but doesn't seem to make much difference
|
// this is an alternative to waitUntilCompleted, which should be faster, but doesn't seem to make much difference
|
||||||
dispatch_semaphore_t completion_semaphore = dispatch_semaphore_create(0);
|
dispatch_semaphore_t completion_semaphore = dispatch_semaphore_create(0);
|
||||||
|
|
||||||
id<MTLCommandQueue> queue = buf->queue;
|
id<MTLCommandBuffer> cmd_buf = [buf->dev->mtl_queue commandBufferWithUnretainedReferences];
|
||||||
id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];
|
|
||||||
|
|
||||||
{
|
{
|
||||||
id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
|
id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
|
||||||
|
|
@ -1454,15 +1596,14 @@ void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_ten
|
||||||
bid_src.offs += offset;
|
bid_src.offs += offset;
|
||||||
|
|
||||||
// dst
|
// dst
|
||||||
id<MTLBuffer> buf_dst = [buf->device newBufferWithBytesNoCopy:data
|
id<MTLBuffer> buf_dst = [buf->dev->mtl_device newBufferWithBytesNoCopy:data
|
||||||
length:size
|
length:size
|
||||||
options:MTLResourceStorageModeShared
|
options:MTLResourceStorageModeShared
|
||||||
deallocator:nil];
|
deallocator:nil];
|
||||||
|
|
||||||
GGML_ASSERT(buf_dst);
|
GGML_ASSERT(buf_dst);
|
||||||
|
|
||||||
id<MTLCommandQueue> queue = buf->queue;
|
id<MTLCommandBuffer> cmd_buf = [buf->dev->mtl_queue commandBufferWithUnretainedReferences];
|
||||||
id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];
|
|
||||||
|
|
||||||
{
|
{
|
||||||
id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
|
id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
|
||||||
|
|
@ -1488,8 +1629,7 @@ void ggml_metal_buffer_clear(ggml_metal_buffer_t buf, uint8_t value) {
|
||||||
}
|
}
|
||||||
|
|
||||||
@autoreleasepool {
|
@autoreleasepool {
|
||||||
id<MTLCommandQueue> queue = buf->queue;
|
id<MTLCommandBuffer> cmd_buf = [buf->dev->mtl_queue commandBufferWithUnretainedReferences];
|
||||||
id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];
|
|
||||||
|
|
||||||
{
|
{
|
||||||
id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
|
id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
|
||||||
|
|
|
||||||
|
|
@ -3083,6 +3083,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
||||||
case GGML_OP_REPEAT:
|
case GGML_OP_REPEAT:
|
||||||
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
|
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
|
||||||
case GGML_OP_PAD:
|
case GGML_OP_PAD:
|
||||||
|
// TODO: add circular padding support for opencl, see https://github.com/ggml-org/llama.cpp/pull/16985
|
||||||
|
if (ggml_get_op_params_i32(op, 8) != 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
||||||
case GGML_OP_UPSCALE: {
|
case GGML_OP_UPSCALE: {
|
||||||
ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & 0xFF);
|
ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & 0xFF);
|
||||||
|
|
|
||||||
|
|
@ -128,6 +128,7 @@ struct rpc_msg_device_count_rsp {
|
||||||
struct rpc_msg_get_alloc_size_req {
|
struct rpc_msg_get_alloc_size_req {
|
||||||
uint32_t device;
|
uint32_t device;
|
||||||
rpc_tensor tensor;
|
rpc_tensor tensor;
|
||||||
|
rpc_tensor srcs[GGML_MAX_SRC];
|
||||||
};
|
};
|
||||||
|
|
||||||
struct rpc_msg_get_alloc_size_rsp {
|
struct rpc_msg_get_alloc_size_rsp {
|
||||||
|
|
@ -572,6 +573,11 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||||
|
|
||||||
static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
|
static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
|
||||||
rpc_tensor result;
|
rpc_tensor result;
|
||||||
|
if (!tensor) {
|
||||||
|
memset(&result, 0, sizeof(result));
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
result.id = reinterpret_cast<uint64_t>(tensor);
|
result.id = reinterpret_cast<uint64_t>(tensor);
|
||||||
result.type = tensor->type;
|
result.type = tensor->type;
|
||||||
if (tensor->buffer) {
|
if (tensor->buffer) {
|
||||||
|
|
@ -753,23 +759,41 @@ static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
|
||||||
}
|
}
|
||||||
|
|
||||||
static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
|
static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
|
||||||
|
// should we query the remote server for the actual size
|
||||||
|
bool rpc_get = false;
|
||||||
|
|
||||||
// See comments in init_tensor.
|
// See comments in init_tensor.
|
||||||
if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) {
|
rpc_get |= ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr);
|
||||||
|
|
||||||
|
// ops that require additional memory for fleeting data on certain backends
|
||||||
|
// ref: https://github.com/ggml-org/llama.cpp/pull/15966
|
||||||
|
rpc_get |= tensor->op == GGML_OP_FLASH_ATTN_EXT;
|
||||||
|
rpc_get |= tensor->op == GGML_OP_MUL_MAT_ID;
|
||||||
|
|
||||||
|
if (rpc_get) {
|
||||||
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
||||||
auto sock = get_socket(buft_ctx->endpoint);
|
auto sock = get_socket(buft_ctx->endpoint);
|
||||||
|
|
||||||
rpc_msg_get_alloc_size_req request;
|
rpc_msg_get_alloc_size_req request = {
|
||||||
request.device = buft_ctx->device;
|
/*.device =*/ buft_ctx->device,
|
||||||
request.tensor = serialize_tensor(tensor);
|
/*.tensor =*/ serialize_tensor(tensor),
|
||||||
|
/*.srcs =*/ {},
|
||||||
|
};
|
||||||
|
|
||||||
|
// .get_alloc_size could be a function of the tensor's srcs, so we must serialize them as well
|
||||||
|
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||||
|
request.srcs[i] = serialize_tensor(tensor->src[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: cache the alloc responses to avoid extra RPC calls?
|
||||||
rpc_msg_get_alloc_size_rsp response;
|
rpc_msg_get_alloc_size_rsp response;
|
||||||
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response));
|
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response));
|
||||||
RPC_STATUS_ASSERT(status);
|
RPC_STATUS_ASSERT(status);
|
||||||
|
|
||||||
return response.alloc_size;
|
return response.alloc_size;
|
||||||
} else {
|
|
||||||
return ggml_nbytes(tensor);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return ggml_nbytes(tensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
|
static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
|
||||||
|
|
@ -1017,7 +1041,7 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
|
||||||
}
|
}
|
||||||
ggml_backend_buffer_type_t buft;
|
ggml_backend_buffer_type_t buft;
|
||||||
struct ggml_init_params params {
|
struct ggml_init_params params {
|
||||||
/*.mem_size =*/ ggml_tensor_overhead(),
|
/*.mem_size =*/ ggml_tensor_overhead()*(1 + GGML_MAX_SRC),
|
||||||
/*.mem_buffer =*/ NULL,
|
/*.mem_buffer =*/ NULL,
|
||||||
/*.no_alloc =*/ true,
|
/*.no_alloc =*/ true,
|
||||||
};
|
};
|
||||||
|
|
@ -1025,12 +1049,18 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
|
||||||
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
||||||
GGML_ASSERT(ctx_ptr != nullptr);
|
GGML_ASSERT(ctx_ptr != nullptr);
|
||||||
ggml_context * ctx = ctx_ptr.get();
|
ggml_context * ctx = ctx_ptr.get();
|
||||||
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
|
|
||||||
|
|
||||||
|
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
|
||||||
if (tensor == nullptr) {
|
if (tensor == nullptr) {
|
||||||
GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
|
GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||||
|
if (request.srcs[i].id != 0) {
|
||||||
|
tensor->src[i] = deserialize_tensor(ctx, &request.srcs[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
LOG_DBG("[%s] device: %d, buffer: %p, data: %p\n", __func__, dev_id, (void*)tensor->buffer, tensor->data);
|
LOG_DBG("[%s] device: %d, buffer: %p, data: %p\n", __func__, dev_id, (void*)tensor->buffer, tensor->data);
|
||||||
if (tensor->buffer == nullptr) {
|
if (tensor->buffer == nullptr) {
|
||||||
//No buffer allocated.
|
//No buffer allocated.
|
||||||
|
|
@ -1227,7 +1257,8 @@ bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
|
||||||
char hash_str[17];
|
char hash_str[17];
|
||||||
snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
|
snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
|
||||||
fs::path cache_file = fs::path(cache_dir) / hash_str;
|
fs::path cache_file = fs::path(cache_dir) / hash_str;
|
||||||
if (!fs::exists(cache_file)) {
|
std::error_code ec;
|
||||||
|
if (!fs::exists(cache_file, ec)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
std::ifstream ifs(cache_file, std::ios::binary);
|
std::ifstream ifs(cache_file, std::ios::binary);
|
||||||
|
|
|
||||||
|
|
@ -4613,6 +4613,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||||
case GGML_OP_ACC:
|
case GGML_OP_ACC:
|
||||||
return true;
|
return true;
|
||||||
case GGML_OP_PAD:
|
case GGML_OP_PAD:
|
||||||
|
// TODO: add circular padding support for sycl, see https://github.com/ggml-org/llama.cpp/pull/16985
|
||||||
|
if (ggml_get_op_params_i32(op, 8) != 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
return ggml_is_contiguous(op->src[0]);
|
return ggml_is_contiguous(op->src[0]);
|
||||||
case GGML_OP_LEAKY_RELU:
|
case GGML_OP_LEAKY_RELU:
|
||||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||||
|
|
|
||||||
|
|
@ -353,10 +353,17 @@ enum vk_conv_shapes {
|
||||||
CONV_SHAPE_COUNT,
|
CONV_SHAPE_COUNT,
|
||||||
};
|
};
|
||||||
|
|
||||||
uint32_t conv_shapes_wg_denoms[][3] = {
|
struct vk_conv_block_size {
|
||||||
{ 128, 128, 1 },
|
uint32_t K;
|
||||||
{ 64, 32, 1 },
|
uint32_t NPQ;
|
||||||
{ 32, 256, 1 },
|
uint32_t CRS;
|
||||||
|
};
|
||||||
|
|
||||||
|
vk_conv_block_size vk_conv_block_sizes[CONV_SHAPE_COUNT] = {
|
||||||
|
// K NPQ CRS
|
||||||
|
{ 128, 128, 16 }, // CONV_SHAPE_128x128
|
||||||
|
{ 64, 32, 32 }, // CONV_SHAPE_64x32
|
||||||
|
{ 32, 256, 16 }, // CONV_SHAPE_32x256
|
||||||
};
|
};
|
||||||
|
|
||||||
enum dmmv_wg_sizes {
|
enum dmmv_wg_sizes {
|
||||||
|
|
@ -519,6 +526,7 @@ struct vk_device_struct {
|
||||||
bool fp16;
|
bool fp16;
|
||||||
bool bf16;
|
bool bf16;
|
||||||
bool pipeline_robustness;
|
bool pipeline_robustness;
|
||||||
|
bool memory_priority;
|
||||||
vk::Device device;
|
vk::Device device;
|
||||||
uint32_t vendor_id;
|
uint32_t vendor_id;
|
||||||
vk::DriverId driver_id;
|
vk::DriverId driver_id;
|
||||||
|
|
@ -1042,6 +1050,7 @@ struct vk_op_pad_push_constants {
|
||||||
uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
|
uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
|
||||||
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
|
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
|
||||||
uint32_t misalign_offsets;
|
uint32_t misalign_offsets;
|
||||||
|
uint32_t circular;
|
||||||
|
|
||||||
uint32_t lp0; uint32_t rp0;
|
uint32_t lp0; uint32_t rp0;
|
||||||
uint32_t lp1; uint32_t rp1;
|
uint32_t lp1; uint32_t rp1;
|
||||||
|
|
@ -1084,6 +1093,7 @@ static vk_op_pad_push_constants vk_op_pad_push_constants_init(const ggml_tensor
|
||||||
p.rp2 = dst->op_params[5];
|
p.rp2 = dst->op_params[5];
|
||||||
p.lp3 = dst->op_params[6];
|
p.lp3 = dst->op_params[6];
|
||||||
p.rp3 = dst->op_params[7];
|
p.rp3 = dst->op_params[7];
|
||||||
|
p.circular = dst->op_params[8];
|
||||||
|
|
||||||
return p; // fastdiv values and offsets are initialized later in ggml_vk_op
|
return p; // fastdiv values and offsets are initialized later in ggml_vk_op
|
||||||
}
|
}
|
||||||
|
|
@ -1343,20 +1353,11 @@ struct vk_op_conv2d_push_constants {
|
||||||
uint32_t Cin;
|
uint32_t Cin;
|
||||||
uint32_t N;
|
uint32_t N;
|
||||||
|
|
||||||
uint32_t KW;
|
|
||||||
uint32_t KH;
|
|
||||||
uint32_t W;
|
uint32_t W;
|
||||||
uint32_t H;
|
uint32_t H;
|
||||||
uint32_t OW;
|
uint32_t OW;
|
||||||
uint32_t OH;
|
uint32_t OH;
|
||||||
|
|
||||||
uint32_t s0;
|
|
||||||
uint32_t s1;
|
|
||||||
uint32_t p0;
|
|
||||||
uint32_t p1;
|
|
||||||
uint32_t d0;
|
|
||||||
uint32_t d1;
|
|
||||||
|
|
||||||
uint32_t nb01;
|
uint32_t nb01;
|
||||||
uint32_t nb02;
|
uint32_t nb02;
|
||||||
uint32_t nb03;
|
uint32_t nb03;
|
||||||
|
|
@ -1380,48 +1381,6 @@ template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
|
||||||
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
|
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct vk_op_conv_transpose_2d_push_constants {
|
|
||||||
uint32_t Cout;
|
|
||||||
uint32_t Cin;
|
|
||||||
uint32_t N;
|
|
||||||
|
|
||||||
uint32_t KW;
|
|
||||||
uint32_t KH;
|
|
||||||
uint32_t W;
|
|
||||||
uint32_t H;
|
|
||||||
uint32_t OW;
|
|
||||||
uint32_t OH;
|
|
||||||
|
|
||||||
uint32_t s0;
|
|
||||||
uint32_t s1;
|
|
||||||
uint32_t p0;
|
|
||||||
uint32_t p1;
|
|
||||||
uint32_t d0;
|
|
||||||
uint32_t d1;
|
|
||||||
|
|
||||||
uint32_t nb01;
|
|
||||||
uint32_t nb02;
|
|
||||||
uint32_t nb03;
|
|
||||||
|
|
||||||
uint32_t nb11;
|
|
||||||
uint32_t nb12;
|
|
||||||
uint32_t nb13;
|
|
||||||
|
|
||||||
uint32_t nb1;
|
|
||||||
uint32_t nb2;
|
|
||||||
uint32_t nb3;
|
|
||||||
|
|
||||||
// init_fastdiv_values constants for dividing by OW, OW*OH
|
|
||||||
uint32_t OWmp; uint32_t OWL;
|
|
||||||
uint32_t OWOHmp; uint32_t OWOHL;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <> void init_pushconst_fastdiv(vk_op_conv_transpose_2d_push_constants &p) {
|
|
||||||
// Compute magic values to divide by OW, OW*OH
|
|
||||||
init_fastdiv_values(p.OW, p.OWmp, p.OWL);
|
|
||||||
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct vk_op_conv2d_dw_push_constants {
|
struct vk_op_conv2d_dw_push_constants {
|
||||||
uint32_t ne;
|
uint32_t ne;
|
||||||
uint32_t batches;
|
uint32_t batches;
|
||||||
|
|
@ -2369,7 +2328,13 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
|
||||||
|
|
||||||
vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
|
vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
|
||||||
|
|
||||||
const vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags };
|
const vk::MemoryPriorityAllocateInfoEXT mem_priority_info { 1.0f };
|
||||||
|
|
||||||
|
vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags };
|
||||||
|
|
||||||
|
if (device->memory_priority) {
|
||||||
|
mem_flags_info.setPNext(&mem_priority_info);
|
||||||
|
}
|
||||||
|
|
||||||
for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
|
for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
|
||||||
const auto & req_flags = *it;
|
const auto & req_flags = *it;
|
||||||
|
|
@ -3574,7 +3539,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
SHADER_REDUCTION_MODE_SHMEM;
|
SHADER_REDUCTION_MODE_SHMEM;
|
||||||
|
|
||||||
for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
|
for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[reduc], arr_dmmv_f32_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[reduc], arr_dmmv_f32_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32", arr_dmmv_f16_f32_f32_len[reduc], arr_dmmv_f16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32", arr_dmmv_f16_f32_f32_len[reduc], arr_dmmv_f16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32", arr_dmmv_bf16_f32_f32_len[reduc], arr_dmmv_bf16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32", arr_dmmv_bf16_f32_f32_len[reduc], arr_dmmv_bf16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32", arr_dmmv_q4_0_f32_f32_len[reduc], arr_dmmv_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32", arr_dmmv_q4_0_f32_f32_len[reduc], arr_dmmv_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
|
||||||
|
|
@ -3598,7 +3563,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[reduc16], arr_dmmv_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[reduc16], arr_dmmv_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[reduc16], arr_dmmv_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[reduc16], arr_dmmv_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32", arr_dmmv_bf16_f16_f32_len[reduc], arr_dmmv_bf16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32", arr_dmmv_bf16_f16_f32_len[reduc], arr_dmmv_bf16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32", arr_dmmv_q4_0_f16_f32_len[reduc], arr_dmmv_q4_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32", arr_dmmv_q4_0_f16_f32_len[reduc], arr_dmmv_q4_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
|
||||||
|
|
@ -3644,7 +3609,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
|
#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", arr_dmmv_id_f32_f32_f32_len[reduc], arr_dmmv_id_f32_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", arr_dmmv_id_f32_f32_f32_len[reduc], arr_dmmv_id_f32_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {wg_size_subgroup, 1}, 1, false, use_subgroups, force_subgroup_size);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", arr_dmmv_id_f16_f32_f32_len[reduc], arr_dmmv_id_f16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", arr_dmmv_id_f16_f32_f32_len[reduc], arr_dmmv_id_f16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", arr_dmmv_id_bf16_f32_f32_len[reduc], arr_dmmv_id_bf16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", arr_dmmv_id_bf16_f32_f32_len[reduc], arr_dmmv_id_bf16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", arr_dmmv_id_q4_0_f32_f32_len[reduc], arr_dmmv_id_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", arr_dmmv_id_q4_0_f32_f32_len[reduc], arr_dmmv_id_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
|
||||||
|
|
@ -4050,7 +4015,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE +
|
uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE +
|
||||||
sizeof(int) * device->subgroup_size +
|
sizeof(int) * device->subgroup_size +
|
||||||
2 * sizeof(int) +
|
2 * sizeof(int) +
|
||||||
(BLOCK_SIZE / device->subgroup_size) * sizeof(int);
|
2 * (BLOCK_SIZE / device->subgroup_size) * sizeof(int);
|
||||||
if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot &&
|
if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot &&
|
||||||
nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) {
|
nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) {
|
||||||
ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size);
|
ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size);
|
||||||
|
|
@ -4070,10 +4035,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
|
|
||||||
for (auto &s : device->pipeline_solve_tri_f32) {
|
for (auto &s : device->pipeline_solve_tri_f32) {
|
||||||
const vk_solve_tri_pipeline_state &state = s.first;
|
const vk_solve_tri_pipeline_state &state = s.first;
|
||||||
|
|
||||||
|
// Max number of rows to load at a time, limited by shared memory
|
||||||
|
const uint32_t batch_N = device->properties.limits.maxComputeSharedMemorySize / ((state.N + state.K) * sizeof(float));
|
||||||
|
// Need at least K invocations, and prefer a minimum of 128 to spread out loading shared memory
|
||||||
|
const uint32_t block_size = std::max(128u, 1u << (uint32_t)ceilf(log2f(float(state.K))));
|
||||||
|
|
||||||
ggml_vk_create_pipeline(
|
ggml_vk_create_pipeline(
|
||||||
device, s.second, "solve_tri_f32",
|
device, s.second, "solve_tri_f32",
|
||||||
solve_tri_f32_len, solve_tri_f32_data, "main", 3,
|
solve_tri_f32_len, solve_tri_f32_data, "main", 3,
|
||||||
sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K }, 1, true);
|
sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K, batch_N, block_size }, 1, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
#define IM2COL(bda) \
|
#define IM2COL(bda) \
|
||||||
|
|
@ -4119,12 +4090,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
// conv2d, conv_transpose_2d
|
// conv2d, conv_transpose_2d
|
||||||
for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
|
for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
|
||||||
uint32_t conv2d_WG_SIZE = 256;
|
uint32_t conv2d_WG_SIZE = 256;
|
||||||
uint32_t conv2d_BS_K = 128;
|
|
||||||
uint32_t conv2d_BS_CRS = 16;
|
|
||||||
uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices.
|
uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices.
|
||||||
uint32_t conv2d_BS_NPQ = 128;
|
uint32_t conv2d_TS_K = (s == CONV_SHAPE_64x32) ? 4 : 8;
|
||||||
uint32_t conv2d_TS_K = 8;
|
|
||||||
uint32_t conv2d_SHMEM_PAD = 4;
|
uint32_t conv2d_SHMEM_PAD = 4;
|
||||||
|
vk_conv_block_size conv2d_BS = vk_conv_block_sizes[s];
|
||||||
bool conv2d_UNROLL = true;
|
bool conv2d_UNROLL = true;
|
||||||
|
|
||||||
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||||
|
|
@ -4138,29 +4107,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
conv2d_UNROLL = false;
|
conv2d_UNROLL = false;
|
||||||
} else if (device->vendor_id == VK_VENDOR_ID_AMD) {
|
} else if (device->vendor_id == VK_VENDOR_ID_AMD) {
|
||||||
conv2d_SHMEM_PAD = device->architecture == vk_device_architecture::AMD_GCN ? 1 : 4;
|
conv2d_SHMEM_PAD = device->architecture == vk_device_architecture::AMD_GCN ? 1 : 4;
|
||||||
}
|
if (s == CONV_SHAPE_128x128 && device->architecture != vk_device_architecture::AMD_GCN) {
|
||||||
|
|
||||||
switch (s) {
|
|
||||||
default:
|
|
||||||
case CONV_SHAPE_128x128:
|
|
||||||
conv2d_BS_K = conv_shapes_wg_denoms[CONV_SHAPE_128x128][0];
|
|
||||||
conv2d_BS_NPQ = conv_shapes_wg_denoms[CONV_SHAPE_128x128][1];
|
|
||||||
conv2d_BS_CRS = 16;
|
|
||||||
if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != vk_device_architecture::AMD_GCN) {
|
|
||||||
conv2d_UNROLL = false;
|
conv2d_UNROLL = false;
|
||||||
}
|
}
|
||||||
break;
|
|
||||||
case CONV_SHAPE_64x32:
|
|
||||||
conv2d_BS_K = conv_shapes_wg_denoms[CONV_SHAPE_64x32][0];
|
|
||||||
conv2d_BS_NPQ = conv_shapes_wg_denoms[CONV_SHAPE_64x32][1];
|
|
||||||
conv2d_BS_CRS = 32;
|
|
||||||
conv2d_TS_K = 4;
|
|
||||||
break;
|
|
||||||
case CONV_SHAPE_32x256:
|
|
||||||
conv2d_BS_K = conv_shapes_wg_denoms[CONV_SHAPE_32x256][0];
|
|
||||||
conv2d_BS_NPQ = conv_shapes_wg_denoms[CONV_SHAPE_32x256][1];
|
|
||||||
conv2d_BS_CRS = 16;
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use collectives on pre-Turing NVIDIA GPUs and GCN AMD cards, which had slower integer math.
|
// Use collectives on pre-Turing NVIDIA GPUs and GCN AMD cards, which had slower integer math.
|
||||||
|
|
@ -4174,22 +4123,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
allow_collectives_nv &&
|
allow_collectives_nv &&
|
||||||
allow_collectives_amd) {
|
allow_collectives_amd) {
|
||||||
use_collectives = 1;
|
use_collectives = 1;
|
||||||
conv2d_BS_CRS = std::min(
|
conv2d_BS.CRS = std::min(
|
||||||
device->subgroup_size,
|
device->subgroup_size,
|
||||||
conv2d_BS_CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used.
|
conv2d_BS.CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used.
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t conv2d_shmem_req =
|
uint32_t conv2d_shmem_req =
|
||||||
(conv2d_BS_K * (conv2d_BS_CRS + conv2d_SHMEM_PAD) + conv2d_BS_CRS * (conv2d_BS_NPQ + conv2d_SHMEM_PAD)) * sizeof(float);
|
(conv2d_BS.K * (conv2d_BS.CRS + conv2d_SHMEM_PAD) + conv2d_BS.CRS * (conv2d_BS.NPQ + conv2d_SHMEM_PAD)) * sizeof(float);
|
||||||
if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) {
|
if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) {
|
||||||
conv2d_BS_CRS = 8;
|
conv2d_BS.CRS = 8;
|
||||||
if (use_collectives) {
|
if (use_collectives) {
|
||||||
conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS);
|
conv2d_BS.CRS = std::min(device->subgroup_size, conv2d_BS.CRS);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::array<uint32_t, 3> wg_denoms = { conv2d_BS_K, conv2d_BS_NPQ, 1 };
|
std::array<uint32_t, 3> wg_denoms = { conv2d_BS.K, 1, 1 };
|
||||||
std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD };
|
std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS.K, conv2d_BS.CRS, conv2d_BS.NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD };
|
||||||
|
|
||||||
#define CREATE_CONV(name, type_suffix, spv_suffix) \
|
#define CREATE_CONV(name, type_suffix, spv_suffix) \
|
||||||
for (auto &c : device->pipeline_##name##type_suffix[s]) { \
|
for (auto &c : device->pipeline_##name##type_suffix[s]) { \
|
||||||
|
|
@ -4206,15 +4155,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
ggml_vk_create_pipeline( \
|
ggml_vk_create_pipeline( \
|
||||||
device, c.second, #name #type_suffix, \
|
device, c.second, #name #type_suffix, \
|
||||||
name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \
|
name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \
|
||||||
sizeof(vk_op_##name##_push_constants), wg_denoms, spec_constants_cpy, 1, true, use_collectives); \
|
sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants_cpy, 1, true, use_collectives); \
|
||||||
}
|
}
|
||||||
#define CREATE_CONVS(spv_suffix) \
|
#define CREATE_CONVS(spv_suffix) \
|
||||||
CREATE_CONV(conv2d, _f32, spv_suffix) \
|
CREATE_CONV(conv2d, _f32, spv_suffix) \
|
||||||
CREATE_CONV(conv2d, _f16_f32, spv_suffix) \
|
CREATE_CONV(conv2d, _f16_f32, spv_suffix) \
|
||||||
if (device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_conv_transpose_2d_push_constants)) { \
|
|
||||||
CREATE_CONV(conv_transpose_2d, _f32, spv_suffix) \
|
CREATE_CONV(conv_transpose_2d, _f32, spv_suffix) \
|
||||||
CREATE_CONV(conv_transpose_2d, _f16_f32, spv_suffix) \
|
CREATE_CONV(conv_transpose_2d, _f16_f32, spv_suffix)
|
||||||
}
|
|
||||||
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||||
if (device->coopmat2) {
|
if (device->coopmat2) {
|
||||||
CREATE_CONVS(_cm2)
|
CREATE_CONVS(_cm2)
|
||||||
|
|
@ -4235,9 +4182,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
|
||||||
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
|
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
|
||||||
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0}, 1, true, true);
|
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0}, 1, true, true, device->subgroup_size);
|
||||||
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0}, 1, true, true);
|
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0}, 1, true, true, device->subgroup_size);
|
||||||
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1}, 1, true, true);
|
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1}, 1, true, true, device->subgroup_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto &c : compiles) {
|
for (auto &c : compiles) {
|
||||||
|
|
@ -4340,6 +4287,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||||
#endif
|
#endif
|
||||||
} else if (strcmp("VK_KHR_pipeline_executable_properties", properties.extensionName) == 0) {
|
} else if (strcmp("VK_KHR_pipeline_executable_properties", properties.extensionName) == 0) {
|
||||||
pipeline_executable_properties_support = true;
|
pipeline_executable_properties_support = true;
|
||||||
|
} else if (strcmp("VK_EXT_memory_priority", properties.extensionName) == 0 &&
|
||||||
|
getenv("GGML_VK_ENABLE_MEMORY_PRIORITY")) {
|
||||||
|
device->memory_priority = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -4531,6 +4481,16 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||||
device_extensions.push_back("VK_EXT_pipeline_robustness");
|
device_extensions.push_back("VK_EXT_pipeline_robustness");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
VkPhysicalDeviceMemoryPriorityFeaturesEXT memory_priority_features;
|
||||||
|
memory_priority_features.pNext = nullptr;
|
||||||
|
memory_priority_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_PRIORITY_FEATURES_EXT;
|
||||||
|
memory_priority_features.memoryPriority = VK_FALSE;
|
||||||
|
if (device->memory_priority) {
|
||||||
|
last_struct->pNext = (VkBaseOutStructure *)&memory_priority_features;
|
||||||
|
last_struct = (VkBaseOutStructure *)&memory_priority_features;
|
||||||
|
device_extensions.push_back("VK_EXT_memory_priority");
|
||||||
|
}
|
||||||
|
|
||||||
VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features;
|
VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features;
|
||||||
subgroup_size_control_features.pNext = nullptr;
|
subgroup_size_control_features.pNext = nullptr;
|
||||||
subgroup_size_control_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT;
|
subgroup_size_control_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT;
|
||||||
|
|
@ -5110,7 +5070,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool ggml_vk_instance_validation_ext_available();
|
static bool ggml_vk_instance_layer_settings_available();
|
||||||
static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
|
static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
|
||||||
static bool ggml_vk_instance_debug_utils_ext_available(const std::vector<vk::ExtensionProperties> & instance_extensions);
|
static bool ggml_vk_instance_debug_utils_ext_available(const std::vector<vk::ExtensionProperties> & instance_extensions);
|
||||||
static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev);
|
static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev);
|
||||||
|
|
@ -5139,19 +5099,19 @@ static void ggml_vk_instance_init() {
|
||||||
vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, api_version };
|
vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, api_version };
|
||||||
|
|
||||||
const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties();
|
const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties();
|
||||||
const bool validation_ext = ggml_vk_instance_validation_ext_available();
|
const bool layer_settings = ggml_vk_instance_layer_settings_available();
|
||||||
#ifdef __APPLE__
|
#ifdef __APPLE__
|
||||||
const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions);
|
const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions);
|
||||||
#endif
|
#endif
|
||||||
const bool debug_utils_ext = ggml_vk_instance_debug_utils_ext_available(instance_extensions) && getenv("GGML_VK_DEBUG_MARKERS") != nullptr;
|
const bool debug_utils_ext = ggml_vk_instance_debug_utils_ext_available(instance_extensions) && getenv("GGML_VK_DEBUG_MARKERS") != nullptr;
|
||||||
std::vector<const char*> layers;
|
std::vector<const char*> layers;
|
||||||
|
|
||||||
if (validation_ext) {
|
if (layer_settings) {
|
||||||
layers.push_back("VK_LAYER_KHRONOS_validation");
|
layers.push_back("VK_LAYER_KHRONOS_validation");
|
||||||
}
|
}
|
||||||
std::vector<const char*> extensions;
|
std::vector<const char*> extensions;
|
||||||
if (validation_ext) {
|
if (layer_settings) {
|
||||||
extensions.push_back("VK_EXT_validation_features");
|
extensions.push_back("VK_EXT_layer_settings");
|
||||||
}
|
}
|
||||||
#ifdef __APPLE__
|
#ifdef __APPLE__
|
||||||
if (portability_enumeration_ext) {
|
if (portability_enumeration_ext) {
|
||||||
|
|
@ -5161,26 +5121,24 @@ static void ggml_vk_instance_init() {
|
||||||
if (debug_utils_ext) {
|
if (debug_utils_ext) {
|
||||||
extensions.push_back("VK_EXT_debug_utils");
|
extensions.push_back("VK_EXT_debug_utils");
|
||||||
}
|
}
|
||||||
vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions);
|
VkBool32 enable_best_practice = layer_settings;
|
||||||
|
std::vector<vk::LayerSettingEXT> settings = {
|
||||||
|
{
|
||||||
|
"VK_LAYER_KHRONOS_validation",
|
||||||
|
"validate_best_practices",
|
||||||
|
vk::LayerSettingTypeEXT::eBool32,
|
||||||
|
1,
|
||||||
|
&enable_best_practice
|
||||||
|
},
|
||||||
|
};
|
||||||
|
vk::LayerSettingsCreateInfoEXT layer_setting_info(settings);
|
||||||
|
vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions, &layer_setting_info);
|
||||||
#ifdef __APPLE__
|
#ifdef __APPLE__
|
||||||
if (portability_enumeration_ext) {
|
if (portability_enumeration_ext) {
|
||||||
instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR;
|
instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
std::vector<vk::ValidationFeatureEnableEXT> features_enable;
|
|
||||||
vk::ValidationFeaturesEXT validation_features;
|
|
||||||
|
|
||||||
if (validation_ext) {
|
|
||||||
features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices };
|
|
||||||
validation_features = {
|
|
||||||
features_enable,
|
|
||||||
{},
|
|
||||||
};
|
|
||||||
validation_features.setPNext(nullptr);
|
|
||||||
instance_create_info.setPNext(&validation_features);
|
|
||||||
GGML_LOG_DEBUG("ggml_vulkan: Validation layers enabled\n");
|
|
||||||
}
|
|
||||||
vk_instance.instance = vk::createInstance(instance_create_info);
|
vk_instance.instance = vk::createInstance(instance_create_info);
|
||||||
vk_instance_initialized = true;
|
vk_instance_initialized = true;
|
||||||
|
|
||||||
|
|
@ -6928,6 +6886,10 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
|
||||||
// Quantization overhead is not worth it for small k
|
// Quantization overhead is not worth it for small k
|
||||||
switch (device->vendor_id) {
|
switch (device->vendor_id) {
|
||||||
case VK_VENDOR_ID_NVIDIA:
|
case VK_VENDOR_ID_NVIDIA:
|
||||||
|
if (src0_type == GGML_TYPE_Q2_K) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
if (k <= 4096) {
|
if (k <= 4096) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
@ -8260,59 +8222,23 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::array<uint32_t, 3> ggml_vk_get_conv_elements(const ggml_tensor *dst) {
|
static vk_conv_shapes ggml_vk_conv_select_shape(ggml_backend_vk_context * ctx, uint32_t K, uint32_t NPQ) {
|
||||||
const ggml_tensor *src0 = dst->src[0];
|
auto n_tiles = [&](vk_conv_shapes s) {
|
||||||
const ggml_tensor *src1 = dst->src[1];
|
return CEIL_DIV(K, vk_conv_block_sizes[s].K)
|
||||||
|
* CEIL_DIV(NPQ, vk_conv_block_sizes[s].NPQ);
|
||||||
// src0 - kernel: [KW, KH, Cin, Cout]
|
|
||||||
// src1 - input: [W, H, Cin, N]
|
|
||||||
// dst - result: [OW, OH, Cout, N]
|
|
||||||
|
|
||||||
// Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d)
|
|
||||||
auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
|
|
||||||
return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
|
|
||||||
};
|
};
|
||||||
// parallelize in {OW/BS_K, OH/BS_NPQ, 1}
|
|
||||||
int64_t W = src1->ne[0];
|
|
||||||
int64_t H = src1->ne[1];
|
|
||||||
int64_t KW = src0->ne[0];
|
|
||||||
int64_t KH = src0->ne[1];
|
|
||||||
int64_t Cout = src0->ne[3];
|
|
||||||
int64_t N = src1->ne[3];
|
|
||||||
int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]);
|
|
||||||
int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]);
|
|
||||||
int64_t NPQ = N * OW * OH;
|
|
||||||
|
|
||||||
// Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups
|
// We can't query number of shader cores on Intel, use 32 as a placeholder
|
||||||
std::array<uint32_t, 3> elements = { static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1 };
|
// so small convolutions will still choose a smaller tile.
|
||||||
return elements;
|
const uint32_t shader_core_count = ctx->device->shader_core_count > 0 ? ctx->device->shader_core_count : 32;
|
||||||
}
|
|
||||||
|
|
||||||
static std::array<uint32_t, 3> ggml_vk_get_conv_transpose_2d_elements(const ggml_tensor *dst) {
|
if (K > 64 && n_tiles(CONV_SHAPE_128x128) >= shader_core_count * 2) {
|
||||||
const ggml_tensor *src0 = dst->src[0];
|
return CONV_SHAPE_128x128;
|
||||||
const ggml_tensor *src1 = dst->src[1];
|
} else if (K <= 32 && n_tiles(CONV_SHAPE_32x256) >= shader_core_count * 2) {
|
||||||
|
return CONV_SHAPE_32x256;
|
||||||
// src0 - kernel: [KW, KH, Cout, Cin]
|
} else {
|
||||||
// src1 - input: [W, H, Cin, N]
|
return CONV_SHAPE_64x32;
|
||||||
// dst - result: [OW, OH, Cout, N]
|
}
|
||||||
|
|
||||||
auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
|
|
||||||
return (ins - 1) * s - 2 * p + (ks - 1) * d + 1;
|
|
||||||
};
|
|
||||||
// parallelize in {OW/BS_K, OH/BS_NPQ, 1}
|
|
||||||
int64_t W = src1->ne[0];
|
|
||||||
int64_t H = src1->ne[1];
|
|
||||||
int64_t KW = src0->ne[0];
|
|
||||||
int64_t KH = src0->ne[1];
|
|
||||||
int64_t Cout = src0->ne[2];
|
|
||||||
int64_t N = src1->ne[3];
|
|
||||||
int64_t OH = calc_conv_output_size(H, KH, dst->op_params[0], 0, 1);
|
|
||||||
int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], 0, 1);
|
|
||||||
int64_t NPQ = N * OW * OH;
|
|
||||||
|
|
||||||
// Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups
|
|
||||||
std::array<uint32_t, 3> elements = { static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1 };
|
|
||||||
return elements;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * dst, ggml_op op) {
|
static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * dst, ggml_op op) {
|
||||||
|
|
@ -8775,39 +8701,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||||
return nullptr;
|
return nullptr;
|
||||||
case GGML_OP_CONV_2D:
|
case GGML_OP_CONV_2D:
|
||||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||||
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
|
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||||
ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
|
uint32_t K = dst->ne[2]; // Cout
|
||||||
std::array<uint32_t, 3> elements{};
|
uint32_t NPQ = dst->ne[3] * dst->ne[1] * dst->ne[0]; // N * OH * OW
|
||||||
if (op == GGML_OP_CONV_2D) elements = ggml_vk_get_conv_elements(dst);
|
vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, K, NPQ);
|
||||||
else if (op == GGML_OP_CONV_TRANSPOSE_2D) elements = ggml_vk_get_conv_transpose_2d_elements(dst);
|
|
||||||
vk_conv_shapes shape;
|
|
||||||
|
|
||||||
uint32_t tiles[CONV_SHAPE_COUNT];
|
|
||||||
for (uint32_t i = 0; i < CONV_SHAPE_COUNT; ++i) {
|
|
||||||
tiles[i] = CEIL_DIV(elements[0], conv_shapes_wg_denoms[i][0]) * CEIL_DIV(elements[1], conv_shapes_wg_denoms[i][1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
// We can't query number of shader cores on Intel, use 32 as a placeholder
|
|
||||||
// so small convolutions will still choose a smaller tile.
|
|
||||||
const uint32_t shader_core_count = ctx->device->shader_core_count > 0 ? ctx->device->shader_core_count : 32;
|
|
||||||
|
|
||||||
if (elements[0] > 64 && tiles[CONV_SHAPE_128x128] >= shader_core_count * 2) {
|
|
||||||
shape = CONV_SHAPE_128x128;
|
|
||||||
} else if (elements[0] <= 32 && tiles[CONV_SHAPE_32x256] >= shader_core_count * 2) {
|
|
||||||
shape = CONV_SHAPE_32x256;
|
|
||||||
} else {
|
|
||||||
shape = CONV_SHAPE_64x32;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t KW = static_cast<uint32_t>(src0->ne[0]);
|
|
||||||
uint32_t KH = static_cast<uint32_t>(src0->ne[1]);
|
|
||||||
uint32_t s0 = static_cast<uint32_t>(dst->op_params[0]);
|
|
||||||
uint32_t s1 = op == GGML_OP_CONV_2D ? static_cast<uint32_t>(dst->op_params[1]) : static_cast<uint32_t>(dst->op_params[0]);
|
|
||||||
uint32_t p0 = op == GGML_OP_CONV_2D ? static_cast<uint32_t>(dst->op_params[2]) : 0;
|
|
||||||
uint32_t p1 = op == GGML_OP_CONV_2D ? static_cast<uint32_t>(dst->op_params[3]) : 0;
|
|
||||||
uint32_t d0 = op == GGML_OP_CONV_2D ? static_cast<uint32_t>(dst->op_params[4]) : 1;
|
|
||||||
uint32_t d1 = op == GGML_OP_CONV_2D ? static_cast<uint32_t>(dst->op_params[5]) : 1;
|
|
||||||
|
|
||||||
|
bool transpose = dst->op == GGML_OP_CONV_TRANSPOSE_2D;
|
||||||
|
uint32_t KW = (uint32_t)src0->ne[0];
|
||||||
|
uint32_t KH = (uint32_t)src0->ne[1];
|
||||||
|
uint32_t s0 = (uint32_t)(ggml_get_op_params_i32(dst, 0));
|
||||||
|
uint32_t s1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 1) : s0;
|
||||||
|
uint32_t p0 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 2) : 0;
|
||||||
|
uint32_t p1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 3) : 0;
|
||||||
|
uint32_t d0 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 4) : 1;
|
||||||
|
uint32_t d1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 5) : 1;
|
||||||
vk_conv2d_pipeline_state conv2d_pipeline_state(s0, s1, p0, p1, d0, d1, KW, KH);
|
vk_conv2d_pipeline_state conv2d_pipeline_state(s0, s1, p0, p1, d0, d1, KW, KH);
|
||||||
|
|
||||||
std::map<vk_conv2d_pipeline_state, vk_pipeline> *pipelines = nullptr;
|
std::map<vk_conv2d_pipeline_state, vk_pipeline> *pipelines = nullptr;
|
||||||
|
|
@ -9126,13 +9033,21 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||||
elements = { N * OC * OH * OW, 1, 1};
|
elements = { N * OC * OH * OW, 1, 1};
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_CONV_2D:
|
case GGML_OP_CONV_2D:
|
||||||
{
|
|
||||||
elements = ggml_vk_get_conv_elements(dst);
|
|
||||||
} break;
|
|
||||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||||
{
|
if constexpr (std::is_same_v<PC, vk_op_conv2d_push_constants>) {
|
||||||
elements = ggml_vk_get_conv_transpose_2d_elements(dst);
|
const uint32_t NPQ = pc.N * pc.OH * pc.OW;
|
||||||
} break;
|
const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, pc.Cout, NPQ);
|
||||||
|
const uint32_t NPQ_blocks = CEIL_DIV(NPQ, vk_conv_block_sizes[shape].NPQ);
|
||||||
|
|
||||||
|
elements = { pc.Cout, NPQ_blocks, 1 };
|
||||||
|
if (elements[1] > 512) {
|
||||||
|
elements[2] = CEIL_DIV(elements[1], 512);
|
||||||
|
elements[1] = 512;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
GGML_ABORT("invalid push constant type for CONV_2D");
|
||||||
|
}
|
||||||
|
break;
|
||||||
case GGML_OP_ADD:
|
case GGML_OP_ADD:
|
||||||
case GGML_OP_SUB:
|
case GGML_OP_SUB:
|
||||||
case GGML_OP_DIV:
|
case GGML_OP_DIV:
|
||||||
|
|
@ -10683,30 +10598,24 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx,
|
||||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
GGML_TENSOR_BINARY_OP_LOCALS
|
GGML_TENSOR_BINARY_OP_LOCALS
|
||||||
|
|
||||||
GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
|
GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
|
||||||
GGML_ASSERT(nb10 == sizeof(float));
|
GGML_ASSERT(nb10 == sizeof(float));
|
||||||
GGML_ASSERT(nb0 == sizeof(float));
|
GGML_ASSERT(nb0 == sizeof(float));
|
||||||
|
|
||||||
|
bool transpose = dst->op == GGML_OP_CONV_TRANSPOSE_2D;
|
||||||
|
|
||||||
vk_op_conv2d_push_constants p{};
|
vk_op_conv2d_push_constants p{};
|
||||||
p.Cout = static_cast<uint32_t>(ne03);
|
p.Cout = static_cast<uint32_t>(!transpose ? ne03 : ne02);
|
||||||
p.Cin = static_cast<uint32_t>(ne02);
|
p.Cin = static_cast<uint32_t>(!transpose ? ne02 : ne03);
|
||||||
p.N = static_cast<uint32_t>(ne13);
|
p.N = static_cast<uint32_t>(ne13);
|
||||||
|
GGML_ASSERT(p.Cout == ne2);
|
||||||
|
GGML_ASSERT(p.Cin == ne12);
|
||||||
|
|
||||||
p.KW = static_cast<uint32_t>(ne00);
|
|
||||||
p.KH = static_cast<uint32_t>(ne01);
|
|
||||||
p.W = static_cast<uint32_t>(ne10);
|
p.W = static_cast<uint32_t>(ne10);
|
||||||
p.H = static_cast<uint32_t>(ne11);
|
p.H = static_cast<uint32_t>(ne11);
|
||||||
p.OW = static_cast<uint32_t>(ne0);
|
p.OW = static_cast<uint32_t>(ne0);
|
||||||
p.OH = static_cast<uint32_t>(ne1);
|
p.OH = static_cast<uint32_t>(ne1);
|
||||||
|
|
||||||
p.s0 = static_cast<uint32_t>(dst->op_params[0]);
|
|
||||||
p.s1 = static_cast<uint32_t>(dst->op_params[1]);
|
|
||||||
p.p0 = static_cast<uint32_t>(dst->op_params[2]);
|
|
||||||
p.p1 = static_cast<uint32_t>(dst->op_params[3]);
|
|
||||||
p.d0 = static_cast<uint32_t>(dst->op_params[4]);
|
|
||||||
p.d1 = static_cast<uint32_t>(dst->op_params[5]);
|
|
||||||
|
|
||||||
p.nb01 = static_cast<uint32_t>(nb01 / nb00);
|
p.nb01 = static_cast<uint32_t>(nb01 / nb00);
|
||||||
p.nb02 = static_cast<uint32_t>(nb02 / nb00);
|
p.nb02 = static_cast<uint32_t>(nb02 / nb00);
|
||||||
p.nb03 = static_cast<uint32_t>(nb03 / nb00);
|
p.nb03 = static_cast<uint32_t>(nb03 / nb00);
|
||||||
|
|
@ -10719,59 +10628,7 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx,
|
||||||
p.nb2 = static_cast<uint32_t>(nb2 / nb0);
|
p.nb2 = static_cast<uint32_t>(nb2 / nb0);
|
||||||
p.nb3 = static_cast<uint32_t>(nb3 / nb0);
|
p.nb3 = static_cast<uint32_t>(nb3 / nb0);
|
||||||
|
|
||||||
GGML_ASSERT(ne03 == ne2);
|
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, std::move(p));
|
||||||
GGML_ASSERT(ne02 == ne12);
|
|
||||||
|
|
||||||
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_2D, std::move(p));
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ggml_vk_conv_transpose_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
|
|
||||||
const ggml_tensor * src1, ggml_tensor * dst) {
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
||||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
|
||||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
|
||||||
|
|
||||||
GGML_TENSOR_BINARY_OP_LOCALS
|
|
||||||
|
|
||||||
GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
|
|
||||||
GGML_ASSERT(nb10 == sizeof(float));
|
|
||||||
GGML_ASSERT(nb0 == sizeof(float));
|
|
||||||
|
|
||||||
vk_op_conv_transpose_2d_push_constants p{};
|
|
||||||
p.Cout = static_cast<uint32_t>(ne02);
|
|
||||||
p.Cin = static_cast<uint32_t>(ne03);
|
|
||||||
p.N = static_cast<uint32_t>(ne13);
|
|
||||||
|
|
||||||
p.KW = static_cast<uint32_t>(ne00);
|
|
||||||
p.KH = static_cast<uint32_t>(ne01);
|
|
||||||
p.W = static_cast<uint32_t>(ne10);
|
|
||||||
p.H = static_cast<uint32_t>(ne11);
|
|
||||||
p.OW = static_cast<uint32_t>(ne0);
|
|
||||||
p.OH = static_cast<uint32_t>(ne1);
|
|
||||||
|
|
||||||
p.s0 = static_cast<uint32_t>(dst->op_params[0]);
|
|
||||||
p.s1 = static_cast<uint32_t>(dst->op_params[0]);
|
|
||||||
p.p0 = 0;
|
|
||||||
p.p1 = 0;
|
|
||||||
p.d0 = 1;
|
|
||||||
p.d1 = 1;
|
|
||||||
|
|
||||||
p.nb01 = static_cast<uint32_t>(nb01 / nb00);
|
|
||||||
p.nb02 = static_cast<uint32_t>(nb02 / nb00);
|
|
||||||
p.nb03 = static_cast<uint32_t>(nb03 / nb00);
|
|
||||||
|
|
||||||
p.nb11 = static_cast<uint32_t>(nb11 / nb10);
|
|
||||||
p.nb12 = static_cast<uint32_t>(nb12 / nb10);
|
|
||||||
p.nb13 = static_cast<uint32_t>(nb13 / nb10);
|
|
||||||
|
|
||||||
p.nb1 = static_cast<uint32_t>(nb1 / nb0);
|
|
||||||
p.nb2 = static_cast<uint32_t>(nb2 / nb0);
|
|
||||||
p.nb3 = static_cast<uint32_t>(nb3 / nb0);
|
|
||||||
|
|
||||||
GGML_ASSERT(ne02 == ne2);
|
|
||||||
GGML_ASSERT(ne03 == ne12);
|
|
||||||
|
|
||||||
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_2D, std::move(p));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
|
|
@ -12142,11 +11999,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||||
|
|
||||||
break;
|
break;
|
||||||
case GGML_OP_CONV_2D:
|
case GGML_OP_CONV_2D:
|
||||||
ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node);
|
|
||||||
|
|
||||||
break;
|
|
||||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||||
ggml_vk_conv_transpose_2d(ctx, compute_ctx, src0, src1, node);
|
ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node);
|
||||||
|
|
||||||
break;
|
break;
|
||||||
case GGML_OP_CONV_2D_DW:
|
case GGML_OP_CONV_2D_DW:
|
||||||
|
|
@ -14179,10 +14033,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||||
const uint32_t N = op->src[0]->ne[0];
|
const uint32_t N = op->src[0]->ne[0];
|
||||||
const uint32_t K = op->src[1]->ne[0];
|
const uint32_t K = op->src[1]->ne[0];
|
||||||
// K dimension limited to workgroup size
|
// K dimension limited to workgroup size
|
||||||
if (K > 128) {
|
if (K > 1u << device->max_workgroup_size_log2) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (N * N * sizeof(float) + N * K * sizeof(float) > device->properties.limits.maxComputeSharedMemorySize) {
|
const uint32_t batch_N = device->properties.limits.maxComputeSharedMemorySize / ((N + K) * sizeof(float));
|
||||||
|
|
||||||
|
if (batch_N == 0) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
|
|
@ -14255,13 +14111,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||||
case GGML_OP_CONV_2D:
|
case GGML_OP_CONV_2D:
|
||||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||||
{
|
{
|
||||||
// Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK
|
|
||||||
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
|
||||||
const vk_device& device = ggml_vk_get_device(ctx->device);
|
|
||||||
if (op->op == GGML_OP_CONV_TRANSPOSE_2D &&
|
|
||||||
device->properties.limits.maxPushConstantsSize < sizeof(vk_op_conv_transpose_2d_push_constants)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
// Channel-contiguous format is not supported yet.
|
// Channel-contiguous format is not supported yet.
|
||||||
return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
||||||
op->src[1]->type == GGML_TYPE_F32 &&
|
op->src[1]->type == GGML_TYPE_F32 &&
|
||||||
|
|
@ -14386,21 +14235,21 @@ ggml_backend_reg_t ggml_backend_vk_reg() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extension availability
|
// Extension availability
|
||||||
static bool ggml_vk_instance_validation_ext_available() {
|
static bool ggml_vk_instance_layer_settings_available() {
|
||||||
#ifdef GGML_VULKAN_VALIDATE
|
#ifdef GGML_VULKAN_VALIDATE
|
||||||
// Check if validation layer provides the extension
|
// Check if validation layer provides the extension
|
||||||
const std::string layer_name = "VK_LAYER_KHRONOS_validation";
|
const std::string layer_name = "VK_LAYER_KHRONOS_validation";
|
||||||
for (const auto& layer : vk::enumerateInstanceLayerProperties()) {
|
for (const auto& layer : vk::enumerateInstanceLayerProperties()) {
|
||||||
if (layer_name == layer.layerName.data()) {
|
if (layer_name == layer.layerName.data()) {
|
||||||
for (const auto& ext : vk::enumerateInstanceExtensionProperties(layer_name)) {
|
for (const auto& ext : vk::enumerateInstanceExtensionProperties(layer_name)) {
|
||||||
if (strcmp("VK_EXT_validation_features", ext.extensionName.data()) == 0) {
|
if (strcmp("VK_EXT_layer_settings", ext.extensionName.data()) == 0) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::cerr << "ggml_vulkan: WARNING: Validation layer or layer extension VK_EXT_validation_features not found." << std::endl;
|
std::cerr << "ggml_vulkan: WARNING: Validation layer or layer extension VK_EXT_layer_settings not found." << std::endl;
|
||||||
#endif
|
#endif
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -32,22 +32,12 @@ layout(push_constant) uniform parameter {
|
||||||
uint32_t Cin;
|
uint32_t Cin;
|
||||||
uint32_t N;
|
uint32_t N;
|
||||||
|
|
||||||
// Tensor spatial sizes: kernel, input, output
|
// Tensor spatial sizes: input, output
|
||||||
uint32_t KW;
|
|
||||||
uint32_t KH;
|
|
||||||
uint32_t W;
|
uint32_t W;
|
||||||
uint32_t H;
|
uint32_t H;
|
||||||
uint32_t OW;
|
uint32_t OW;
|
||||||
uint32_t OH;
|
uint32_t OH;
|
||||||
|
|
||||||
// Parameters: stride, padding, dilation - 0=y, 1=x
|
|
||||||
uint32_t s0;
|
|
||||||
uint32_t s1;
|
|
||||||
uint32_t p0;
|
|
||||||
uint32_t p1;
|
|
||||||
uint32_t d0;
|
|
||||||
uint32_t d1;
|
|
||||||
|
|
||||||
// Strides in elements
|
// Strides in elements
|
||||||
uint32_t nb01;
|
uint32_t nb01;
|
||||||
uint32_t nb02;
|
uint32_t nb02;
|
||||||
|
|
@ -77,13 +67,14 @@ layout(constant_id = 3) const uint BS_NPQ = 128;
|
||||||
layout(constant_id = 4) const uint TS_K = 8;
|
layout(constant_id = 4) const uint TS_K = 8;
|
||||||
layout(constant_id = 5) const uint use_collectives = 1;
|
layout(constant_id = 5) const uint use_collectives = 1;
|
||||||
layout(constant_id = 6) const uint SHMEM_PAD = 4;
|
layout(constant_id = 6) const uint SHMEM_PAD = 4;
|
||||||
|
// Stride, padding, dilation
|
||||||
layout(constant_id = 7) const uint s0 = 1;
|
layout(constant_id = 7) const uint s0 = 1;
|
||||||
layout(constant_id = 8) const uint s1 = 1;
|
layout(constant_id = 8) const uint s1 = 1;
|
||||||
layout(constant_id = 9) const uint p0 = 0;
|
layout(constant_id = 9) const uint p0 = 0;
|
||||||
layout(constant_id = 10) const uint p1 = 0;
|
layout(constant_id = 10) const uint p1 = 0;
|
||||||
layout(constant_id = 11) const uint d0 = 1;
|
layout(constant_id = 11) const uint d0 = 1;
|
||||||
layout(constant_id = 12) const uint d1 = 1;
|
layout(constant_id = 12) const uint d1 = 1;
|
||||||
|
// Kernel spatial sizes
|
||||||
layout(constant_id = 13) const uint KW = 1;
|
layout(constant_id = 13) const uint KW = 1;
|
||||||
layout(constant_id = 14) const uint KH = 1;
|
layout(constant_id = 14) const uint KH = 1;
|
||||||
|
|
||||||
|
|
@ -138,7 +129,7 @@ P,Q=OH,OW
|
||||||
*/
|
*/
|
||||||
|
|
||||||
uint32_t B_idx_K = gl_WorkGroupID.x;
|
uint32_t B_idx_K = gl_WorkGroupID.x;
|
||||||
uint32_t B_idx_NPQ = gl_WorkGroupID.y;
|
uint32_t B_idx_NPQ = gl_WorkGroupID.y + gl_WorkGroupID.z * 512;
|
||||||
|
|
||||||
uint32_t T_y = tid / NT_NPQ;
|
uint32_t T_y = tid / NT_NPQ;
|
||||||
uint32_t T_x = tid % NT_NPQ;
|
uint32_t T_x = tid % NT_NPQ;
|
||||||
|
|
@ -178,6 +169,10 @@ ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_T
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
|
if (B_idx_NPQ * BS_NPQ >= NPQ) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
#ifdef COOPMAT2
|
#ifdef COOPMAT2
|
||||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator> matC;
|
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator> matC;
|
||||||
matC = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator>(0.0);
|
matC = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator>(0.0);
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ layout (push_constant) uniform parameter
|
||||||
uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
|
uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
|
||||||
uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
|
uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
|
||||||
uint misalign_offsets;
|
uint misalign_offsets;
|
||||||
|
uint circular;
|
||||||
|
|
||||||
uint lp0; uint rp0;
|
uint lp0; uint rp0;
|
||||||
uint lp1; uint rp1;
|
uint lp1; uint rp1;
|
||||||
|
|
@ -18,6 +19,10 @@ layout (push_constant) uniform parameter
|
||||||
uint get_aoffset() { return p.misalign_offsets >> 16; }
|
uint get_aoffset() { return p.misalign_offsets >> 16; }
|
||||||
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
|
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
|
||||||
|
|
||||||
|
uint wrap_around(int coord, uint size) {
|
||||||
|
return (uint(coord + int(size))) % size; // add size to avoid issues with negative
|
||||||
|
}
|
||||||
|
|
||||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||||
|
|
||||||
|
|
@ -40,10 +45,20 @@ void main() {
|
||||||
const uint src0_idx = (i3 - p.lp3)*p.nb03 + (i2 - p.lp2)*p.nb02 + (i1 - p.lp1)*p.nb01 + (i0 - p.lp0)*p.nb00;
|
const uint src0_idx = (i3 - p.lp3)*p.nb03 + (i2 - p.lp2)*p.nb02 + (i1 - p.lp1)*p.nb01 + (i0 - p.lp0)*p.nb00;
|
||||||
const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10;
|
const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10;
|
||||||
|
|
||||||
|
if (p.circular != 0u) {
|
||||||
|
const uint ci0 = wrap_around(int(i0) - int(p.lp0), p.ne00);
|
||||||
|
const uint ci1 = wrap_around(int(i1) - int(p.lp1), p.ne01);
|
||||||
|
const uint ci2 = wrap_around(int(i2) - int(p.lp2), p.ne02);
|
||||||
|
const uint ci3 = wrap_around(int(i3) - int(p.lp3), p.ne03);
|
||||||
|
const uint circular_src_idx = ci3*p.nb03 + ci2*p.nb02 + ci1*p.nb01 + ci0*p.nb00;
|
||||||
|
data_d[get_doffset() + dst_idx] = D_TYPE(data_a[get_aoffset() + circular_src_idx]);
|
||||||
|
} else {
|
||||||
const bool is_src0 = i0 >= p.lp0 && i0 < p.ne10 - p.rp0 &&
|
const bool is_src0 = i0 >= p.lp0 && i0 < p.ne10 - p.rp0 &&
|
||||||
i1 >= p.lp1 && i1 < p.ne11 - p.rp1 &&
|
i1 >= p.lp1 && i1 < p.ne11 - p.rp1 &&
|
||||||
i2 >= p.lp2 && i2 < p.ne12 - p.rp2 &&
|
i2 >= p.lp2 && i2 < p.ne12 - p.rp2 &&
|
||||||
i3 >= p.lp3 && i3 < p.ne13 - p.rp3;
|
i3 >= p.lp3 && i3 < p.ne13 - p.rp3;
|
||||||
|
|
||||||
data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f);
|
data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -131,8 +131,12 @@ void main() {
|
||||||
rms_norm(num_blocks);
|
rms_norm(num_blocks);
|
||||||
} else if (num_blocks > 16) {
|
} else if (num_blocks > 16) {
|
||||||
rms_norm(32);
|
rms_norm(32);
|
||||||
} else if (num_blocks > 8) {
|
} else if (num_blocks > 12) {
|
||||||
rms_norm(16);
|
rms_norm(16);
|
||||||
|
} else if (num_blocks > 10) {
|
||||||
|
rms_norm(12);
|
||||||
|
} else if (num_blocks > 8) {
|
||||||
|
rms_norm(10);
|
||||||
} else if (num_blocks > 4) {
|
} else if (num_blocks > 4) {
|
||||||
rms_norm(8);
|
rms_norm(8);
|
||||||
} else if (num_blocks == 4) {
|
} else if (num_blocks == 4) {
|
||||||
|
|
|
||||||
|
|
@ -5,8 +5,9 @@
|
||||||
|
|
||||||
layout (constant_id = 1) const uint N = 64;
|
layout (constant_id = 1) const uint N = 64;
|
||||||
layout (constant_id = 2) const uint K = 32;
|
layout (constant_id = 2) const uint K = 32;
|
||||||
|
layout (constant_id = 3) const uint BATCH_N = 32;
|
||||||
|
|
||||||
layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
|
layout(local_size_x_id = 4, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
uint a_base, b_base, x_base;
|
uint a_base, b_base, x_base;
|
||||||
|
|
||||||
|
|
@ -22,8 +23,8 @@ void store_x(uint r, uint c, FLOAT_TYPE v) {
|
||||||
data_d[x_base + r * p.nb21 + c * p.nb20] = D_TYPE(v);
|
data_d[x_base + r * p.nb21 + c * p.nb20] = D_TYPE(v);
|
||||||
}
|
}
|
||||||
|
|
||||||
shared FLOAT_TYPE shA[N * N];
|
shared FLOAT_TYPE shA[BATCH_N * N];
|
||||||
shared FLOAT_TYPE shB[N * K];
|
shared FLOAT_TYPE shB[BATCH_N * K];
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
const uint batch = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
|
const uint batch = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
|
||||||
|
|
@ -39,34 +40,42 @@ void main() {
|
||||||
b_base = get_boffset() + i2 * p.nb12 + i3 * p.nb13;
|
b_base = get_boffset() + i2 * p.nb12 + i3 * p.nb13;
|
||||||
x_base = get_doffset() + i2 * p.nb22 + i3 * p.nb23;
|
x_base = get_doffset() + i2 * p.nb22 + i3 * p.nb23;
|
||||||
|
|
||||||
// Load the A matrix into shA
|
FLOAT_TYPE X[N];
|
||||||
[[unroll]] for (uint i = 0; i < N * N; i += gl_WorkGroupSize.x) {
|
|
||||||
|
// Loop over batches of rows
|
||||||
|
[[unroll]] for (uint row_base = 0; row_base < N; row_base += BATCH_N) {
|
||||||
|
const uint cur_N = min(BATCH_N, N - row_base);
|
||||||
|
|
||||||
|
// Load the A matrix batch into shA
|
||||||
|
[[unroll]] for (uint i = 0; i < cur_N * N; i += gl_WorkGroupSize.x) {
|
||||||
uint idx = i + tid;
|
uint idx = i + tid;
|
||||||
if (((N * N) % gl_WorkGroupSize.x == 0) || idx < N * N) {
|
if (((cur_N * N) % gl_WorkGroupSize.x == 0) || idx < cur_N * N) {
|
||||||
shA[idx] = get_a(idx / N, idx % N);
|
shA[idx] = get_a(row_base + idx / N, idx % N);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Load the B matrix into shB
|
// Load the B matrix batch into shB
|
||||||
[[unroll]] for (uint i = 0; i < N * K; i += gl_WorkGroupSize.x) {
|
[[unroll]] for (uint i = 0; i < cur_N * K; i += gl_WorkGroupSize.x) {
|
||||||
uint idx = i + tid;
|
uint idx = i + tid;
|
||||||
if (((N * K) % gl_WorkGroupSize.x == 0) || idx < N * K) {
|
if (((cur_N * K) % gl_WorkGroupSize.x == 0) || idx < cur_N * K) {
|
||||||
shB[idx] = get_b(idx / K, idx % K);
|
shB[idx] = get_b(row_base + idx / K, idx % K);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
barrier();
|
barrier();
|
||||||
|
|
||||||
FLOAT_TYPE X[N];
|
|
||||||
// Each thread solves one column
|
// Each thread solves one column
|
||||||
if (tid < K) {
|
if (tid < K) {
|
||||||
[[unroll]] for (int r = 0; r < N; ++r) {
|
[[unroll]] for (uint row_offset = 0; row_offset < cur_N; ++row_offset) {
|
||||||
FLOAT_TYPE b = shB[r * K + tid];
|
uint r = row_base + row_offset;
|
||||||
|
FLOAT_TYPE b = shB[row_offset * K + tid];
|
||||||
// Compute x[r,c] = (b[r,c] - sum(a[r,c]*x[c])) / a[r,r]
|
// Compute x[r,c] = (b[r,c] - sum(a[r,c]*x[c])) / a[r,r]
|
||||||
[[unroll]] for (int c = 0; c < r; ++c) {
|
[[unroll]] for (int c = 0; c < r; ++c) {
|
||||||
b -= shA[r * N + c] * X[c];
|
b -= shA[row_offset * N + c] * X[c];
|
||||||
}
|
}
|
||||||
FLOAT_TYPE x = b / shA[r * N + r];
|
FLOAT_TYPE x = b / shA[row_offset * N + r];
|
||||||
X[r] = x;
|
X[r] = x;
|
||||||
store_x(r, tid, x);
|
store_x(r, tid, x);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -75,7 +75,7 @@ void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit
|
||||||
}
|
}
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
|
const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID;
|
||||||
if (row >= n_rows) {
|
if (row >= n_rows) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
@ -83,17 +83,18 @@ void main() {
|
||||||
const uint logits_offset = n_experts * row;
|
const uint logits_offset = n_experts * row;
|
||||||
const uint weights_offset = n_expert_used * row;
|
const uint weights_offset = n_expert_used * row;
|
||||||
const uint ids_offset = n_experts * row;
|
const uint ids_offset = n_experts * row;
|
||||||
|
const uint lane = gl_SubgroupInvocationID;
|
||||||
|
|
||||||
float wt[experts_per_thread];
|
float wt[experts_per_thread];
|
||||||
|
|
||||||
[[unroll]]
|
[[unroll]]
|
||||||
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
|
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
|
||||||
const uint expert = i + gl_LocalInvocationID.x;
|
const uint expert = i + lane;
|
||||||
wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
|
wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!late_softmax) {
|
if (!late_softmax) {
|
||||||
softmax_warp_inplace(wt, n_experts, gl_LocalInvocationID.x, false);
|
softmax_warp_inplace(wt, n_experts, lane, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
// at this point, each thread holds a portion of softmax,
|
// at this point, each thread holds a portion of softmax,
|
||||||
|
|
@ -111,11 +112,11 @@ void main() {
|
||||||
|
|
||||||
for (int k = 0; k < n_expert_used; k++) {
|
for (int k = 0; k < n_expert_used; k++) {
|
||||||
float max_val = wt[0];
|
float max_val = wt[0];
|
||||||
uint max_expert = gl_LocalInvocationID.x;
|
uint max_expert = lane;
|
||||||
|
|
||||||
[[unroll]]
|
[[unroll]]
|
||||||
for (int i = 1; i < experts_per_thread; i++) {
|
for (int i = 1; i < experts_per_thread; i++) {
|
||||||
const uint expert = gl_LocalInvocationID.x + i * WARP_SIZE;
|
const uint expert = lane + i * WARP_SIZE;
|
||||||
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
|
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
|
||||||
max_val = wt[i];
|
max_val = wt[i];
|
||||||
max_expert = expert;
|
max_expert = expert;
|
||||||
|
|
@ -132,11 +133,11 @@ void main() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if ((k & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
|
if ((k & (WARP_SIZE - 1)) == lane) {
|
||||||
output_weights[k / WARP_SIZE] = max_val;
|
output_weights[k / WARP_SIZE] = max_val;
|
||||||
}
|
}
|
||||||
|
|
||||||
if ((max_expert & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
|
if ((max_expert & (WARP_SIZE - 1)) == lane) {
|
||||||
wt[max_expert / WARP_SIZE] = -INFINITY;
|
wt[max_expert / WARP_SIZE] = -INFINITY;
|
||||||
|
|
||||||
ids[ids_offset + k] = max_expert;
|
ids[ids_offset + k] = max_expert;
|
||||||
|
|
@ -158,12 +159,12 @@ void main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (late_softmax) {
|
if (late_softmax) {
|
||||||
softmax_warp_inplace(output_weights, n_expert_used, gl_LocalInvocationID.x, true);
|
softmax_warp_inplace(output_weights, n_expert_used, lane, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
[[unroll]]
|
[[unroll]]
|
||||||
for (uint i = 0; i < experts_per_thread; ++i) {
|
for (uint i = 0; i < experts_per_thread; ++i) {
|
||||||
uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;
|
uint idx = i * WARP_SIZE + lane;
|
||||||
if (idx < n_expert_used) {
|
if (idx < n_expert_used) {
|
||||||
weights[weights_offset + idx] = output_weights[i];
|
weights[weights_offset + idx] = output_weights[i];
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,7 @@ shared int counts[SUBGROUP_SIZE];
|
||||||
shared int sh_min_idx;
|
shared int sh_min_idx;
|
||||||
shared uint sh_total;
|
shared uint sh_total;
|
||||||
shared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE];
|
shared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE];
|
||||||
|
shared uint eq_min_partials[BLOCK_SIZE / SUBGROUP_SIZE];
|
||||||
|
|
||||||
// Map float values to uint such that comparisons still work.
|
// Map float values to uint such that comparisons still work.
|
||||||
// Positive values set the high bit, negative values are inverted.
|
// Positive values set the high bit, negative values are inverted.
|
||||||
|
|
@ -156,6 +157,11 @@ void topk(const uint row) {
|
||||||
// We need to compact these values to the start of the dst_row array.
|
// We need to compact these values to the start of the dst_row array.
|
||||||
// Have each subgroup count how many items it'll store, so other
|
// Have each subgroup count how many items it'll store, so other
|
||||||
// subgroups can compute their base offset.
|
// subgroups can compute their base offset.
|
||||||
|
// Values strictly greater than range_min must be stored. For values equal
|
||||||
|
// to range_min, there can be ties and it's possible we'll need to store
|
||||||
|
// an arbitrary subset of them.
|
||||||
|
// If total == p.k, have a fast path where we don't need to handle ties.
|
||||||
|
if (total == p.k) {
|
||||||
bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
|
bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
|
||||||
uvec4 b = subgroupBallot(top);
|
uvec4 b = subgroupBallot(top);
|
||||||
uint bit_count = subgroupBallotBitCount(b);
|
uint bit_count = subgroupBallotBitCount(b);
|
||||||
|
|
@ -176,6 +182,42 @@ void topk(const uint row) {
|
||||||
// TODO: Copy directly to the output?
|
// TODO: Copy directly to the output?
|
||||||
dst_row[out_idx + bit_count_ex] = v;
|
dst_row[out_idx + bit_count_ex] = v;
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
bool top = f2ui(intBitsToFloat(v.y)) > range_min;
|
||||||
|
bool eq_min = f2ui(intBitsToFloat(v.y)) == range_min;
|
||||||
|
uvec4 b_top = subgroupBallot(top);
|
||||||
|
uvec4 b_eq_min = subgroupBallot(eq_min);
|
||||||
|
uint bit_count_top = subgroupBallotBitCount(b_top);
|
||||||
|
uint bit_count_eq_min = subgroupBallotBitCount(b_eq_min);
|
||||||
|
if ((tid % SUBGROUP_SIZE) == 0) {
|
||||||
|
offset_partials[tid / SUBGROUP_SIZE] = bit_count_top;
|
||||||
|
eq_min_partials[tid / SUBGROUP_SIZE] = bit_count_eq_min;
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
uint out_idx = 0;
|
||||||
|
uint eq_min_base = 0;
|
||||||
|
uint eq_min_idx = 0;
|
||||||
|
[[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
|
||||||
|
if (i < tid / SUBGROUP_SIZE) {
|
||||||
|
out_idx += offset_partials[i];
|
||||||
|
eq_min_idx += eq_min_partials[i];
|
||||||
|
}
|
||||||
|
eq_min_base += offset_partials[i];
|
||||||
|
}
|
||||||
|
// range_min values are stored at the end
|
||||||
|
eq_min_idx += eq_min_base;
|
||||||
|
|
||||||
|
uint bit_count_ex_top = subgroupBallotExclusiveBitCount(b_top);
|
||||||
|
uint bit_count_ex_eq_min = subgroupBallotExclusiveBitCount(b_eq_min);
|
||||||
|
if (top) {
|
||||||
|
// TODO: Copy directly to the output?
|
||||||
|
dst_row[out_idx + bit_count_ex_top] = v;
|
||||||
|
}
|
||||||
|
if (eq_min && eq_min_idx + bit_count_ex_eq_min < p.k) {
|
||||||
|
dst_row[eq_min_idx + bit_count_ex_eq_min] = v;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
barrier();
|
barrier();
|
||||||
}
|
}
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -19,6 +19,15 @@ def parse_decls(decls_text):
|
||||||
return decls
|
return decls
|
||||||
|
|
||||||
|
|
||||||
|
def replace_repl_placeholders(variant, template_map):
    """Expand REPL template names inside a variant's ``REPLS`` values.

    Each value of ``variant["REPLS"]`` is scanned for whole-word occurrences
    of every key in ``template_map``; each occurrence is replaced by the
    corresponding template text, inserted literally.

    Args:
        variant: Variant dict holding a ``"REPLS"`` mapping of placeholder
            name to replacement code. The mapping is updated in place.
        template_map: Mapping of template name to template text.

    Returns:
        The same ``variant`` object, with its ``"REPLS"`` values expanded.
    """
    repls = variant["REPLS"]
    for repl, code in repls.items():
        for key, val in template_map.items():
            text = str(val)
            # Match "key" as a whole word (\b) so it is not replaced inside a
            # longer identifier. The replacement is given as a callable so
            # that backslashes in the template text are inserted verbatim
            # instead of being parsed as regex escapes or group references.
            code = re.sub(
                rf'\b{re.escape(str(key))}\b',
                lambda _match, text=text: text,
                code,
            )
        repls[repl] = code
    return variant
|
||||||
|
|
||||||
|
|
||||||
def replace_placeholders(shader_text, replacements):
|
def replace_placeholders(shader_text, replacements):
|
||||||
for key, val in replacements.items():
|
for key, val in replacements.items():
|
||||||
# Match {{KEY}} literally, where KEY is escaped
|
# Match {{KEY}} literally, where KEY is escaped
|
||||||
|
|
@ -71,6 +80,10 @@ def generate_variants(fname, input_dir, output_dir, outfile):
|
||||||
decls_map = parse_decls(extract_block(text, "DECLS"))
|
decls_map = parse_decls(extract_block(text, "DECLS"))
|
||||||
except ValueError:
|
except ValueError:
|
||||||
decls_map = {}
|
decls_map = {}
|
||||||
|
try:
|
||||||
|
templates_map = ast.literal_eval(extract_block(text, "REPL_TEMPLATES"))
|
||||||
|
except ValueError:
|
||||||
|
templates_map = {}
|
||||||
|
|
||||||
for fname in sorted(os.listdir(input_dir)):
|
for fname in sorted(os.listdir(input_dir)):
|
||||||
if fname.endswith(".tmpl"):
|
if fname.endswith(".tmpl"):
|
||||||
|
|
@ -90,9 +103,11 @@ def generate_variants(fname, input_dir, output_dir, outfile):
|
||||||
if key not in decls_map:
|
if key not in decls_map:
|
||||||
raise ValueError(f"DECLS key '{key}' not found.")
|
raise ValueError(f"DECLS key '{key}' not found.")
|
||||||
decls_code += decls_map[key] + "\n\n"
|
decls_code += decls_map[key] + "\n\n"
|
||||||
|
|
||||||
final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template)
|
final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template)
|
||||||
if "REPLS" in variant:
|
if "REPLS" in variant:
|
||||||
|
variant = replace_repl_placeholders(variant, templates_map)
|
||||||
|
final_shader = replace_placeholders(final_shader, variant["REPLS"])
|
||||||
|
# second run to expand placeholders in repl_template
|
||||||
final_shader = replace_placeholders(final_shader, variant["REPLS"])
|
final_shader = replace_placeholders(final_shader, variant["REPLS"])
|
||||||
final_shader = expand_includes(final_shader, input_dir)
|
final_shader = expand_includes(final_shader, input_dir)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,461 @@
|
||||||
|
#define(REPL_TEMPLATES)
|
||||||
|
|
||||||
|
{
|
||||||
|
"XIELU_FUNC": "{{MUTATE}}[dst_i] = select(((exp(min(src[src_i], {{TYPE}}(params.eps))) - 1.0) - src[src_i]) * {{TYPE}}(params.alpha_n) + {{TYPE}}(params.beta) * src[src_i], {{TYPE}}(params.alpha_p) * src[src_i] * src[src_i] + {{TYPE}}(params.beta) * src[src_i], src[src_i] > 0.0);",
|
||||||
|
"ABS_FUNC": "{{MUTATE}}[dst_i] = abs(src[src_i]);",
|
||||||
|
"SGN_FUNC": "{{MUTATE}}[dst_i] = select({{TYPE}}(select(0.0, -1.0, src[src_i] < 0.0)), {{TYPE}}(1.0), src[src_i] > 0.0);",
|
||||||
|
"NEG_FUNC": "{{MUTATE}}[dst_i] = -src[src_i];",
|
||||||
|
"STEP_FUNC": "{{MUTATE}}[dst_i] = {{TYPE}}(select(0.0, 1.0, src[src_i] > 0.0));",
|
||||||
|
"TANH_FUNC": "{{MUTATE}}[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
|
||||||
|
"RELU_FUNC": "{{MUTATE}}[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);",
|
||||||
|
"ELU_FUNC": "{{MUTATE}}[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);",
|
||||||
|
"HARDSIGMOID_FUNC": "{{MUTATE}}[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));",
|
||||||
|
"SIGMOID_FUNC": "{{MUTATE}}[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));",
|
||||||
|
"SILU_FUNC": "{{MUTATE}}[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));",
|
||||||
|
"EXP_FUNC": "{{MUTATE}}[dst_i] = exp(src[src_i]);",
|
||||||
|
"HARDSWISH_FUNC": "{{MUTATE}}[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));",
|
||||||
|
"GELU_FUNC": "{{MUTATE}}[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
|
||||||
|
"GELU_QUICK_FUNC": "{{MUTATE}}[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
|
||||||
|
"GELU_ERF_FUNC": "{{MUTATE}}[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458"
|
||||||
|
}
|
||||||
|
|
||||||
|
#end(REPL_TEMPLATES)
|
||||||
|
|
||||||
|
#define(VARIANTS)
|
||||||
|
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "abs_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "abs_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "abs_inplace_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "abs_inplace_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "sgn_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "sgn_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "sgn_inplace_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "sgn_inplace_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "neg_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "neg_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "neg_inplace_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "neg_inplace_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "step_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "step_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "step_inplace_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "step_inplace_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "tanh_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "tanh_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "tanh_inplace_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "tanh_inplace_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "elu_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "elu_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "elu_inplace_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "elu_inplace_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "relu_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "relu_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "relu_inplace_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "relu_inplace_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "sigmoid_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "sigmoid_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "sigmoid_inplace_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "sigmoid_inplace_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "silu_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "silu_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "silu_inplace_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "silu_inplace_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "exp_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "exp_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "exp_inplace_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "exp_inplace_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "hardsigmoid_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "hardsigmoid_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "hardsigmoid_inplace_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "hardsigmoid_inplace_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "hardswish_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "hardswish_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "hardswish_inplace_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "hardswish_inplace_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "gelu_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "gelu_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "gelu_inplace_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "gelu_inplace_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "gelu_quick_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "gelu_quick_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "gelu_quick_inplace_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "gelu_quick_inplace_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "xielu_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "xielu_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "xielu_inplace_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "xielu_inplace_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "gelu_erf_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "gelu_erf_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "gelu_erf_inplace_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "gelu_erf_inplace_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
#end(VARIANTS)
|
||||||
|
|
||||||
|
#define(DECLS)
|
||||||
|
|
||||||
|
#decl(INPLACE)
|
||||||
|
|
||||||
|
@group(0) @binding(1)
|
||||||
|
var<uniform> params: Params;
|
||||||
|
|
||||||
|
#enddecl(INPLACE)
|
||||||
|
|
||||||
|
#decl(NOT_INPLACE)
|
||||||
|
|
||||||
|
@group(0) @binding(1)
|
||||||
|
var<storage, read_write> dst: array<{{TYPE}}>;
|
||||||
|
|
||||||
|
@group(0) @binding(2)
|
||||||
|
var<uniform> params: Params;
|
||||||
|
|
||||||
|
#enddecl(NOT_INPLACE)
|
||||||
|
|
||||||
|
#end(DECLS)
|
||||||
|
|
||||||
|
#define(SHADER)
|
||||||
|
|
||||||
|
enable f16;
|
||||||
|
|
||||||
|
fn update(dst_i: u32, src_i: u32) {
|
||||||
|
{{FUNC}}
|
||||||
|
}
|
||||||
|
|
||||||
|
@group(0) @binding(0)
|
||||||
|
var<storage, read_write> src: array<{{TYPE}}>;
|
||||||
|
|
||||||
|
DECLS
|
||||||
|
|
||||||
|
struct Params {
|
||||||
|
ne: u32, // total number of elements
|
||||||
|
offset_src: u32, // in elements
|
||||||
|
offset_dst: u32, // in elements
|
||||||
|
|
||||||
|
// Strides (in elements) — may be permuted
|
||||||
|
stride_src0: u32,
|
||||||
|
stride_src1: u32,
|
||||||
|
stride_src2: u32,
|
||||||
|
stride_src3: u32,
|
||||||
|
|
||||||
|
stride_dst0: u32,
|
||||||
|
stride_dst1: u32,
|
||||||
|
stride_dst2: u32,
|
||||||
|
stride_dst3: u32,
|
||||||
|
|
||||||
|
// Logical shapes
|
||||||
|
src_ne0: u32,
|
||||||
|
src_ne1: u32,
|
||||||
|
src_ne2: u32,
|
||||||
|
|
||||||
|
dst_ne0: u32,
|
||||||
|
dst_ne1: u32,
|
||||||
|
dst_ne2: u32,
|
||||||
|
|
||||||
|
{{EXT_PARAMS}}
|
||||||
|
};
|
||||||
|
|
||||||
|
override wg_size: u32;
|
||||||
|
// Compute entry point: each invocation handles the element whose flat index
// is gid.x, and invocations at or beyond params.ne exit immediately.
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let flat = gid.x;
    if (flat >= params.ne) {
        return;
    }

    // Unravel the flat index into 4D coordinates using the source shape.
    let src_plane = params.src_ne1 * params.src_ne0;
    let src_volume = params.src_ne2 * src_plane;
    let i3 = flat / src_volume;
    let src_rem3 = flat % src_volume;
    let i2 = src_rem3 / src_plane;
    let src_rem2 = src_rem3 % src_plane;
    let i1 = src_rem2 / params.src_ne0;
    let i0 = src_rem2 % params.src_ne0;

    // Unravel the same flat index again using the destination shape.
    let dst_plane = params.dst_ne1 * params.dst_ne0;
    let dst_volume = params.dst_ne2 * dst_plane;
    let j3 = flat / dst_volume;
    let dst_rem3 = flat % dst_volume;
    let j2 = dst_rem3 / dst_plane;
    let dst_rem2 = dst_rem3 % dst_plane;
    let j1 = dst_rem2 / params.dst_ne0;
    let j0 = dst_rem2 % params.dst_ne0;

    // Combine the coordinates with the per-dimension strides (in elements).
    let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
                  i2 * params.stride_src2 + i3 * params.stride_src3;

    let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +
                  j2 * params.stride_dst2 + j3 * params.stride_dst3;

    // Apply the base offsets and run the variant-specific element update.
    update(params.offset_dst + dst_idx, params.offset_src + src_idx);
}
|
||||||
|
|
||||||
|
#end(SHADER)
|
||||||
|
|
||||||
|
|
@ -4947,6 +4947,18 @@ struct ggml_tensor * ggml_pad(
|
||||||
return ggml_pad_ext(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
|
return ggml_pad_ext(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ggml_pad_circular
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_pad_circular(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        int                   p0,
        int                   p1,
        int                   p2,
        int                   p3) {
    // Left padding is 0 on every dimension; p0..p3 are forwarded as the
    // right-hand paddings of the extended circular variant.
    struct ggml_tensor * result = ggml_pad_ext_circular(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
    return result;
}
|
||||||
|
|
||||||
struct ggml_tensor * ggml_pad_ext(
|
struct ggml_tensor * ggml_pad_ext(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
|
|
@ -4973,6 +4985,7 @@ struct ggml_tensor * ggml_pad_ext(
|
||||||
ggml_set_op_params_i32(result, 5, rp2);
|
ggml_set_op_params_i32(result, 5, rp2);
|
||||||
ggml_set_op_params_i32(result, 6, lp3);
|
ggml_set_op_params_i32(result, 6, lp3);
|
||||||
ggml_set_op_params_i32(result, 7, rp3);
|
ggml_set_op_params_i32(result, 7, rp3);
|
||||||
|
ggml_set_op_params_i32(result, 8, 0); // not circular by default
|
||||||
|
|
||||||
|
|
||||||
result->op = GGML_OP_PAD;
|
result->op = GGML_OP_PAD;
|
||||||
|
|
@ -4981,6 +4994,25 @@ struct ggml_tensor * ggml_pad_ext(
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ggml_pad_ext_circular
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_pad_ext_circular(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        int lp0, int rp0,
        int lp1, int rp1,
        int lp2, int rp2,
        int lp3, int rp3) {
    // Build the pad node through the regular extended-pad constructor ...
    struct ggml_tensor * padded = ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);

    // ... then set op param 8 to 1 to mark the padding as circular.
    ggml_set_op_params_i32(padded, 8, 1);

    return padded;
}
|
||||||
|
|
||||||
// ggml_pad_reflect_1d
|
// ggml_pad_reflect_1d
|
||||||
|
|
||||||
struct ggml_tensor * ggml_pad_reflect_1d(
|
struct ggml_tensor * ggml_pad_reflect_1d(
|
||||||
|
|
|
||||||
|
|
@ -376,6 +376,7 @@ class TensorNameMap:
|
||||||
"model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker
|
"model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker
|
||||||
"model.layers.{bid}.feed_forward.gate", # lfm2moe
|
"model.layers.{bid}.feed_forward.gate", # lfm2moe
|
||||||
"model.layers.{bid}.mlp.router.gate", # afmoe
|
"model.layers.{bid}.mlp.router.gate", # afmoe
|
||||||
|
"layers.{bid}.gate", # mistral-large
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
|
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
|
||||||
|
|
@ -450,6 +451,7 @@ class TensorNameMap:
|
||||||
"model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
|
"model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
|
||||||
"model.layers.{bid}.feed_forward.down_proj",
|
"model.layers.{bid}.feed_forward.down_proj",
|
||||||
"model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan
|
"model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan
|
||||||
|
"layers.{bid}.shared_experts.w3", # mistral-large
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.FFN_UP_CHEXP: (
|
MODEL_TENSOR.FFN_UP_CHEXP: (
|
||||||
|
|
@ -496,6 +498,7 @@ class TensorNameMap:
|
||||||
"model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
|
"model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
|
||||||
"model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4
|
"model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4
|
||||||
"model.layers.{bid}.mlp.shared_mlp.gate_proj", # hunyuan
|
"model.layers.{bid}.mlp.shared_mlp.gate_proj", # hunyuan
|
||||||
|
"layers.{bid}.shared_experts.w1", # mistral-large
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.FFN_GATE_CHEXP: (
|
MODEL_TENSOR.FFN_GATE_CHEXP: (
|
||||||
|
|
@ -557,6 +560,7 @@ class TensorNameMap:
|
||||||
"model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
|
"model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
|
||||||
"model.layers.{bid}.shared_mlp.output_linear", # granitemoe
|
"model.layers.{bid}.shared_mlp.output_linear", # granitemoe
|
||||||
"model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan
|
"model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan
|
||||||
|
"layers.{bid}.shared_experts.w2", # mistral-large
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.FFN_DOWN_CHEXP: (
|
MODEL_TENSOR.FFN_DOWN_CHEXP: (
|
||||||
|
|
@ -924,14 +928,17 @@ class TensorNameMap:
|
||||||
|
|
||||||
MODEL_TENSOR.ATTN_Q_A: (
|
MODEL_TENSOR.ATTN_Q_A: (
|
||||||
"model.layers.{bid}.self_attn.q_a_proj", # deepseek2
|
"model.layers.{bid}.self_attn.q_a_proj", # deepseek2
|
||||||
|
"layers.{bid}.attention.wq_a", # mistral-large
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.ATTN_Q_B: (
|
MODEL_TENSOR.ATTN_Q_B: (
|
||||||
"model.layers.{bid}.self_attn.q_b_proj", # deepseek2
|
"model.layers.{bid}.self_attn.q_b_proj", # deepseek2
|
||||||
|
"layers.{bid}.attention.wq_b", # mistral-large
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.ATTN_KV_A_MQA: (
|
MODEL_TENSOR.ATTN_KV_A_MQA: (
|
||||||
"model.layers.{bid}.self_attn.kv_a_proj_with_mqa", # deepseek2
|
"model.layers.{bid}.self_attn.kv_a_proj_with_mqa", # deepseek2
|
||||||
|
"layers.{bid}.attention.wkv_a_with_mqa", # mistral-large
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.ATTN_KV_B: (
|
MODEL_TENSOR.ATTN_KV_B: (
|
||||||
|
|
@ -940,18 +947,22 @@ class TensorNameMap:
|
||||||
|
|
||||||
MODEL_TENSOR.ATTN_K_B: (
|
MODEL_TENSOR.ATTN_K_B: (
|
||||||
"model.layers.{bid}.self_attn.k_b_proj", # deepseek2
|
"model.layers.{bid}.self_attn.k_b_proj", # deepseek2
|
||||||
|
"layers.{bid}.attention.k_b_proj", # mistral-large
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.ATTN_V_B: (
|
MODEL_TENSOR.ATTN_V_B: (
|
||||||
"model.layers.{bid}.self_attn.v_b_proj", # deepseek2
|
"model.layers.{bid}.self_attn.v_b_proj", # deepseek2
|
||||||
|
"layers.{bid}.attention.v_b_proj", # mistral-large
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.ATTN_Q_A_NORM: (
|
MODEL_TENSOR.ATTN_Q_A_NORM: (
|
||||||
"model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
|
"model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
|
||||||
|
"layers.{bid}.attention.q_a_norm", # mistral-large
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.ATTN_KV_A_NORM: (
|
MODEL_TENSOR.ATTN_KV_A_NORM: (
|
||||||
"model.layers.{bid}.self_attn.kv_a_layernorm", # deepseek2
|
"model.layers.{bid}.self_attn.kv_a_layernorm", # deepseek2
|
||||||
|
"layers.{bid}.attention.kv_a_norm", # mistral-large
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.ATTN_SUB_NORM: (
|
MODEL_TENSOR.ATTN_SUB_NORM: (
|
||||||
|
|
|
||||||
|
|
@ -666,7 +666,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||||
|
|
||||||
std::map<int, std::string> mapped;
|
std::map<int, std::string> mapped;
|
||||||
int blk_id = 0;
|
int blk_id = 0;
|
||||||
int pruned_attention_w = 0;
|
|
||||||
|
|
||||||
// make a list of weights
|
// make a list of weights
|
||||||
std::vector<const llama_model_loader::llama_tensor_weight *> tensors;
|
std::vector<const llama_model_loader::llama_tensor_weight *> tensors;
|
||||||
|
|
@ -674,11 +673,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||||
for (const auto & it : ml.weights_map) {
|
for (const auto & it : ml.weights_map) {
|
||||||
const std::string remapped_name(remap_layer(it.first, prune_list, mapped, blk_id));
|
const std::string remapped_name(remap_layer(it.first, prune_list, mapped, blk_id));
|
||||||
if (remapped_name.empty()) {
|
if (remapped_name.empty()) {
|
||||||
if (it.first.find("attn_v.weight") != std::string::npos ||
|
|
||||||
it.first.find("attn_qkv.weight") != std::string::npos ||
|
|
||||||
it.first.find("attn_kv_b.weight") != std::string::npos) {
|
|
||||||
pruned_attention_w++;
|
|
||||||
}
|
|
||||||
LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str());
|
LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str());
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
@ -703,7 +697,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_clip_model = false;
|
|
||||||
for (const auto * it : tensors) {
|
for (const auto * it : tensors) {
|
||||||
const struct ggml_tensor * tensor = it->tensor;
|
const struct ggml_tensor * tensor = it->tensor;
|
||||||
|
|
||||||
|
|
@ -717,30 +710,10 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||||
} else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
|
} else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
|
||||||
qs.has_output = true;
|
qs.has_output = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
is_clip_model |= name.rfind("mm.", 0) == 0; // check the "mm." prefix
|
|
||||||
}
|
}
|
||||||
|
|
||||||
qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
|
qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
|
||||||
|
|
||||||
// sanity checks for models that have attention layers
|
|
||||||
if (qs.n_attention_wv != 0 && !is_clip_model)
|
|
||||||
{
|
|
||||||
int32_t n_layer_all = model.hparams.n_layer;
|
|
||||||
if (llama_model_has_encoder(&model)) {
|
|
||||||
// now n_layer_all is the number of attention layers in the encoder
|
|
||||||
// for each decoder block, there are 2 attention layers
|
|
||||||
n_layer_all += 2 * model.hparams.dec_n_layer;
|
|
||||||
}
|
|
||||||
|
|
||||||
// note: for linear-attention models (such as Qwen3 Next) this is the number of linear layers
|
|
||||||
const int32_t n_layer_recr = std::count(model.hparams.recurrent_layer_arr.begin(), model.hparams.recurrent_layer_arr.end(), true);
|
|
||||||
|
|
||||||
LLAMA_LOG_INFO("%s: n_layer_all = %d, n_layer_recr = %d, pruned_attention_w = %d\n", __func__, n_layer_all, n_layer_recr, pruned_attention_w);
|
|
||||||
|
|
||||||
GGML_ASSERT((qs.n_attention_wv == n_layer_all - pruned_attention_w - n_layer_recr) && "n_attention_wv is unexpected");
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t total_size_org = 0;
|
size_t total_size_org = 0;
|
||||||
size_t total_size_new = 0;
|
size_t total_size_new = 0;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -286,10 +286,11 @@ static double nmse(const float * a, const float * b, size_t n) {
|
||||||
return mse_a_b / mse_a_0;
|
return mse_a_b / mse_a_0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// difference between 2 integer sets (Jaccard distance, 0 - no difference, 1 - no overlap)
|
// difference between 2 sets (Jaccard distance, 0 - no difference, 1 - no overlap)
|
||||||
static double jdst(const int32_t * a, const int32_t * b, size_t n) {
|
template <typename T>
|
||||||
std::unordered_map<int32_t, size_t> set_a;
|
static double jdst(const T * a, const T * b, size_t n) {
|
||||||
std::unordered_map<int32_t, size_t> set_b;
|
std::unordered_map<T, size_t> set_a;
|
||||||
|
std::unordered_map<T, size_t> set_b;
|
||||||
|
|
||||||
for (size_t i = 0; i < n; ++i) {
|
for (size_t i = 0; i < n; ++i) {
|
||||||
set_a[a[i]]++;
|
set_a[a[i]]++;
|
||||||
|
|
@ -5001,21 +5002,69 @@ struct test_top_k : public test_case {
|
||||||
const ggml_type type;
|
const ggml_type type;
|
||||||
const std::array<int64_t, 4> ne;
|
const std::array<int64_t, 4> ne;
|
||||||
const int k;
|
const int k;
|
||||||
|
const bool ties;
|
||||||
|
ggml_tensor * input {};
|
||||||
|
|
||||||
std::string vars() override {
|
std::string vars() override {
|
||||||
return VARS_TO_STR3(type, ne, k);
|
return VARS_TO_STR4(type, ne, k, ties);
|
||||||
}
|
}
|
||||||
|
|
||||||
test_top_k(ggml_type type = GGML_TYPE_F32,
|
test_top_k(ggml_type type = GGML_TYPE_F32,
|
||||||
std::array<int64_t, 4> ne = {16, 10, 10, 10},
|
std::array<int64_t, 4> ne = {16, 10, 10, 10},
|
||||||
int k = 4)
|
int k = 4, bool ties = false)
|
||||||
: type(type), ne(ne), k(k) {}
|
: type(type), ne(ne), k(k), ties(ties) {}
|
||||||
|
|
||||||
double max_err() override {
|
double max_err() override {
|
||||||
return 0.0;
|
return 0.0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// When there are ties, only validate the final result.
|
||||||
|
// The logic in err can't handle the sentinel tensors.
|
||||||
|
bool run_whole_graph() override { return ties; }
|
||||||
|
|
||||||
double err(const float * a, const float * b, size_t n) override {
|
double err(const float * a, const float * b, size_t n) override {
|
||||||
|
// When there are no ties, we expect the exact same set of indices,
|
||||||
|
// but possibly in a different order. When there are ties, the indices
|
||||||
|
// can be different but the input values they correspond to should be
|
||||||
|
// the same. The logic for ties could work for non-ties, but only for
|
||||||
|
// the output tensor, not for the sentinel tensors.
|
||||||
|
if (ties) {
|
||||||
|
std::vector<float> src(ggml_nelements(input));
|
||||||
|
|
||||||
|
ggml_backend_tensor_get(input, src.data(), 0, ggml_nelements(input) * ggml_type_size(type));
|
||||||
|
|
||||||
|
double diff = 0.0f;
|
||||||
|
|
||||||
|
GGML_ASSERT(n == (size_t)(ggml_nrows(input) * k));
|
||||||
|
int64_t cols = input->ne[0];
|
||||||
|
std::vector<int32_t> ia(k);
|
||||||
|
std::vector<int32_t> ib(k);
|
||||||
|
std::vector<float> asrc(k);
|
||||||
|
std::vector<float> bsrc(k);
|
||||||
|
for (int64_t r = 0; r < ggml_nrows(input); r++) {
|
||||||
|
// Convert indices for the row back to integer
|
||||||
|
for (int64_t c = 0; c < k; c++) {
|
||||||
|
ia[c] = (int32_t)a[r * k + c];
|
||||||
|
ib[c] = (int32_t)b[r * k + c];
|
||||||
|
}
|
||||||
|
// The src values for each row should match.
|
||||||
|
for (int64_t c = 0; c < k; c++) {
|
||||||
|
asrc[c] = src[r * cols + ia[c]];
|
||||||
|
bsrc[c] = src[r * cols + ib[c]];
|
||||||
|
}
|
||||||
|
diff += jdst(asrc.data(), bsrc.data(), k);
|
||||||
|
// There should be no duplicate indices
|
||||||
|
std::sort(ia.begin(), ia.end());
|
||||||
|
std::sort(ib.begin(), ib.end());
|
||||||
|
if (std::adjacent_find(ia.begin(), ia.end()) != ia.end()) {
|
||||||
|
diff += 1;
|
||||||
|
}
|
||||||
|
if (std::adjacent_find(ib.begin(), ib.end()) != ib.end()) {
|
||||||
|
diff += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return diff;
|
||||||
|
} else {
|
||||||
std::vector<int32_t> ia(n);
|
std::vector<int32_t> ia(n);
|
||||||
std::vector<int32_t> ib(n);
|
std::vector<int32_t> ib(n);
|
||||||
|
|
||||||
|
|
@ -5032,11 +5081,15 @@ struct test_top_k : public test_case {
|
||||||
|
|
||||||
return diff + jdst(ia.data(), ib.data(), n);
|
return diff + jdst(ia.data(), ib.data(), n);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||||
ggml_set_name(a, "a");
|
ggml_set_name(a, "a");
|
||||||
|
|
||||||
|
// Save 'a' for err()
|
||||||
|
input = a;
|
||||||
|
|
||||||
ggml_tensor * out = ggml_top_k(ctx, a, k);
|
ggml_tensor * out = ggml_top_k(ctx, a, k);
|
||||||
ggml_set_name(out, "out");
|
ggml_set_name(out, "out");
|
||||||
|
|
||||||
|
|
@ -5047,12 +5100,17 @@ struct test_top_k : public test_case {
|
||||||
std::random_device rd;
|
std::random_device rd;
|
||||||
std::default_random_engine rng(rd());
|
std::default_random_engine rng(rd());
|
||||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||||
// initialize with unique values to avoid ties
|
int tie_denom = std::max(1, std::min(10, k / 2));
|
||||||
for (int64_t r = 0; r < ggml_nrows(t); r++) {
|
for (int64_t r = 0; r < ggml_nrows(t); r++) {
|
||||||
std::vector<float> data(t->ne[0]);
|
std::vector<float> data(t->ne[0]);
|
||||||
for (int i = 0; i < t->ne[0]; i++) {
|
for (int i = 0; i < t->ne[0]; i++) {
|
||||||
|
if (ties) {
|
||||||
|
// integer division to introduce duplicates
|
||||||
|
data[i] = i / tie_denom;
|
||||||
|
} else {
|
||||||
data[i] = i;
|
data[i] = i;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
std::shuffle(data.begin(), data.end(), rng);
|
std::shuffle(data.begin(), data.end(), rng);
|
||||||
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
|
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
|
||||||
}
|
}
|
||||||
|
|
@ -5546,21 +5604,24 @@ struct test_pad : public test_case {
|
||||||
const std::array<int64_t, 4> ne_a;
|
const std::array<int64_t, 4> ne_a;
|
||||||
const int pad_0;
|
const int pad_0;
|
||||||
const int pad_1;
|
const int pad_1;
|
||||||
|
const bool circular;
|
||||||
|
|
||||||
std::string vars() override {
|
std::string vars() override {
|
||||||
return VARS_TO_STR4(type, ne_a, pad_0, pad_1);
|
return VARS_TO_STR5(type, ne_a, pad_0, pad_1, circular);
|
||||||
}
|
}
|
||||||
|
|
||||||
test_pad(ggml_type type = GGML_TYPE_F32,
|
test_pad(ggml_type type = GGML_TYPE_F32,
|
||||||
std::array<int64_t, 4> ne_a = {512, 512, 1, 1},
|
std::array<int64_t, 4> ne_a = {512, 512, 1, 1},
|
||||||
int pad_0 = 1, int pad_1 = 1)
|
int pad_0 = 1, int pad_1 = 1, bool circular = false)
|
||||||
: type(type), ne_a(ne_a), pad_0(pad_0), pad_1(pad_1) {}
|
: type(type), ne_a(ne_a), pad_0(pad_0), pad_1(pad_1), circular(circular) {}
|
||||||
|
|
||||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
|
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
|
||||||
ggml_set_name(a, "a");
|
ggml_set_name(a, "a");
|
||||||
|
|
||||||
ggml_tensor * out = ggml_pad(ctx, a, pad_0, pad_1, 0, 0);
|
ggml_tensor * out = circular
|
||||||
|
? ggml_pad_circular(ctx, a, pad_0, pad_1, 0, 0)
|
||||||
|
: ggml_pad(ctx, a, pad_0, pad_1, 0, 0);
|
||||||
ggml_set_name(out, "out");
|
ggml_set_name(out, "out");
|
||||||
|
|
||||||
return out;
|
return out;
|
||||||
|
|
@ -5580,17 +5641,19 @@ struct test_pad_ext : public test_case {
|
||||||
const int lp3;
|
const int lp3;
|
||||||
const int rp3;
|
const int rp3;
|
||||||
const bool v;
|
const bool v;
|
||||||
|
const bool circular;
|
||||||
|
|
||||||
std::string vars() override {
|
std::string vars() override {
|
||||||
return VARS_TO_STR11(type, ne_a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, v);
|
return VARS_TO_STR12(type, ne_a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, v, circular);
|
||||||
}
|
}
|
||||||
|
|
||||||
test_pad_ext(ggml_type type = GGML_TYPE_F32,
|
test_pad_ext(ggml_type type = GGML_TYPE_F32,
|
||||||
std::array<int64_t, 4> ne_a = {512, 512, 3, 1},
|
std::array<int64_t, 4> ne_a = {512, 512, 3, 1},
|
||||||
int lp0 = 1, int rp0 = 1, int lp1 = 1, int rp1 = 1,
|
int lp0 = 1, int rp0 = 1, int lp1 = 1, int rp1 = 1,
|
||||||
int lp2 = 1, int rp2 = 1, int lp3 = 1, int rp3 = 1,
|
int lp2 = 1, int rp2 = 1, int lp3 = 1, int rp3 = 1,
|
||||||
bool v = false)
|
bool v = false, bool circular = false)
|
||||||
: type(type), ne_a(ne_a), lp0(lp0), rp0(rp0), lp1(lp1), rp1(rp1), lp2(lp2), rp2(rp2), lp3(lp3), rp3(rp3), v(v) {}
|
: type(type), ne_a(ne_a), lp0(lp0), rp0(rp0), lp1(lp1), rp1(rp1), lp2(lp2), rp2(rp2), lp3(lp3), rp3(rp3),
|
||||||
|
v(v), circular(circular) {}
|
||||||
|
|
||||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
|
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
|
||||||
|
|
@ -5601,7 +5664,9 @@ struct test_pad_ext : public test_case {
|
||||||
ggml_set_name(a, "view of a");
|
ggml_set_name(a, "view of a");
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * out = ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
|
ggml_tensor * out = circular
|
||||||
|
? ggml_pad_ext_circular(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3)
|
||||||
|
: ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
|
||||||
ggml_set_name(out, "out");
|
ggml_set_name(out, "out");
|
||||||
|
|
||||||
return out;
|
return out;
|
||||||
|
|
@ -6146,6 +6211,15 @@ struct test_solve_tri : public test_case {
|
||||||
|
|
||||||
std::string vars() override { return VARS_TO_STR3(type, ne_lhs, ne_rhs); }
|
std::string vars() override { return VARS_TO_STR3(type, ne_lhs, ne_rhs); }
|
||||||
|
|
||||||
|
uint64_t op_flops(ggml_tensor * t) override {
|
||||||
|
GGML_UNUSED(t);
|
||||||
|
int64_t n = ne_lhs[0];
|
||||||
|
int64_t k = ne_rhs[0];
|
||||||
|
int64_t batch = ne_lhs[2] * ne_lhs[3];
|
||||||
|
// n * (n + 1) / 2 non-zero elements of lhs, 2 flops each, for each col of rhs
|
||||||
|
return n * (n + 1) * k * batch;
|
||||||
|
}
|
||||||
|
|
||||||
test_solve_tri(ggml_type type = GGML_TYPE_F32,
|
test_solve_tri(ggml_type type = GGML_TYPE_F32,
|
||||||
std::array<int64_t, 4> ne_lhs = { 10, 10, 4, 3 },
|
std::array<int64_t, 4> ne_lhs = { 10, 10, 4, 3 },
|
||||||
std::array<int64_t, 4> ne_rhs = { 3, 10, 4, 3 }
|
std::array<int64_t, 4> ne_rhs = { 3, 10, 4, 3 }
|
||||||
|
|
@ -6982,6 +7056,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||||
|
|
||||||
test_cases.emplace_back(new test_conv_transpose_2d({3, 2, 3, 1}, {2, 2, 1, 3}, 1));
|
test_cases.emplace_back(new test_conv_transpose_2d({3, 2, 3, 1}, {2, 2, 1, 3}, 1));
|
||||||
test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2));
|
test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2));
|
||||||
|
test_cases.emplace_back(new test_conv_transpose_2d({129, 63, 35, 1}, {3, 3, 48, 35}, 1));
|
||||||
|
|
||||||
test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 500, 1, 1}));
|
test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 500, 1, 1}));
|
||||||
test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1}));
|
test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1}));
|
||||||
|
|
@ -7656,6 +7731,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||||
if (k <= 1<<i) {
|
if (k <= 1<<i) {
|
||||||
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i), 1, 1, 1}, k));
|
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i), 1, 1, 1}, k));
|
||||||
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i) + 11, 1, 2, 1}, k));
|
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i) + 11, 1, 2, 1}, k));
|
||||||
|
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i) + 11, 1, 2, 1}, k, true));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -7713,6 +7789,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||||
test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {9, 9, 1280, 1}));
|
test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {9, 9, 1280, 1}));
|
||||||
test_cases.emplace_back(new test_acc());
|
test_cases.emplace_back(new test_acc());
|
||||||
test_cases.emplace_back(new test_pad());
|
test_cases.emplace_back(new test_pad());
|
||||||
|
test_cases.emplace_back(new test_pad(GGML_TYPE_F32, {33, 17, 2, 1}, 4, 3, true)); // circular
|
||||||
test_cases.emplace_back(new test_pad_ext());
|
test_cases.emplace_back(new test_pad_ext());
|
||||||
test_cases.emplace_back(new test_pad_reflect_1d());
|
test_cases.emplace_back(new test_pad_reflect_1d());
|
||||||
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
|
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
|
||||||
|
|
@ -7757,10 +7834,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 42, 42, 5, 2 }, { 10, 42, 5, 2 }));
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 42, 42, 5, 2 }, { 10, 42, 5, 2 }));
|
||||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 2, 2 }, { 10, 64, 2, 2 }));
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 2, 2 }, { 10, 64, 2, 2 }));
|
||||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 100, 100, 4, 4 }, { 41, 100, 4, 4 }));
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 100, 100, 4, 4 }, { 41, 100, 4, 4 }));
|
||||||
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 4 }, { 31, 128, 4, 4 }));
|
||||||
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 300, 64, 4, 4 }));
|
||||||
|
|
||||||
for (bool v : {false, true}) {
|
for (bool v : {false, true}) {
|
||||||
test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v));
|
for (bool circular : {false, true}) {
|
||||||
test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v));
|
test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v, circular));
|
||||||
|
test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v, circular));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int hsk : { 40, 64, 72, 80, 96, 128, 192, 256, 576 }) {
|
for (int hsk : { 40, 64, 72, 80, 96, 128, 192, 256, 576 }) {
|
||||||
|
|
@ -7898,6 +7979,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||||
{ 58, 3, 64, 32, 8 },
|
{ 58, 3, 64, 32, 8 },
|
||||||
// A deep layer of a ConvNet, several images in the batch
|
// A deep layer of a ConvNet, several images in the batch
|
||||||
{ 16, 3, 512, 128, 8 },
|
{ 16, 3, 512, 128, 8 },
|
||||||
|
// High resolution output (large NPQ)
|
||||||
|
{1536, 3, 64, 32, 1 },
|
||||||
};
|
};
|
||||||
|
|
||||||
for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||||
|
|
@ -7955,6 +8038,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||||
|
|
||||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 }));
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 }));
|
||||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 }));
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 }));
|
||||||
|
// qwen3next with CHUNK_SIZE 64
|
||||||
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 8, 32 }, { 64, 64, 8, 32 }));
|
||||||
|
// qwen3next with CHUNK_SIZE 128
|
||||||
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 32 }, { 128, 128, 4, 32 }));
|
||||||
|
|
||||||
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F32, { 256, 256, 4, 4 }));
|
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F32, { 256, 256, 4, 4 }));
|
||||||
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 }));
|
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 }));
|
||||||
|
|
|
||||||
Binary file not shown.
|
|
@ -3,6 +3,7 @@
|
||||||
import { copyToClipboard, isIMEComposing } from '$lib/utils';
|
import { copyToClipboard, isIMEComposing } from '$lib/utils';
|
||||||
import ChatMessageAssistant from './ChatMessageAssistant.svelte';
|
import ChatMessageAssistant from './ChatMessageAssistant.svelte';
|
||||||
import ChatMessageUser from './ChatMessageUser.svelte';
|
import ChatMessageUser from './ChatMessageUser.svelte';
|
||||||
|
import ChatMessageSystem from './ChatMessageSystem.svelte';
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
class?: string;
|
class?: string;
|
||||||
|
|
@ -140,8 +141,7 @@
|
||||||
}
|
}
|
||||||
|
|
||||||
function handleSaveEdit() {
|
function handleSaveEdit() {
|
||||||
if (message.role === 'user') {
|
if (message.role === 'user' || message.role === 'system') {
|
||||||
// For user messages, trim to avoid accidental whitespace
|
|
||||||
onEditWithBranching?.(message, editedContent.trim());
|
onEditWithBranching?.(message, editedContent.trim());
|
||||||
} else {
|
} else {
|
||||||
// For assistant messages, preserve exact content including trailing whitespace
|
// For assistant messages, preserve exact content including trailing whitespace
|
||||||
|
|
@ -167,7 +167,28 @@
|
||||||
}
|
}
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
{#if message.role === 'user'}
|
{#if message.role === 'system'}
|
||||||
|
<ChatMessageSystem
|
||||||
|
bind:textareaElement
|
||||||
|
class={className}
|
||||||
|
{deletionInfo}
|
||||||
|
{editedContent}
|
||||||
|
{isEditing}
|
||||||
|
{message}
|
||||||
|
onCancelEdit={handleCancelEdit}
|
||||||
|
onConfirmDelete={handleConfirmDelete}
|
||||||
|
onCopy={handleCopy}
|
||||||
|
onDelete={handleDelete}
|
||||||
|
onEdit={handleEdit}
|
||||||
|
onEditKeydown={handleEditKeydown}
|
||||||
|
onEditedContentChange={handleEditedContentChange}
|
||||||
|
{onNavigateToSibling}
|
||||||
|
onSaveEdit={handleSaveEdit}
|
||||||
|
onShowDeleteDialogChange={handleShowDeleteDialogChange}
|
||||||
|
{showDeleteDialog}
|
||||||
|
{siblingInfo}
|
||||||
|
/>
|
||||||
|
{:else if message.role === 'user'}
|
||||||
<ChatMessageUser
|
<ChatMessageUser
|
||||||
bind:textareaElement
|
bind:textareaElement
|
||||||
class={className}
|
class={className}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,216 @@
|
||||||
|
<script lang="ts">
|
||||||
|
import { Check, X } from '@lucide/svelte';
|
||||||
|
import { Card } from '$lib/components/ui/card';
|
||||||
|
import { Button } from '$lib/components/ui/button';
|
||||||
|
import { MarkdownContent } from '$lib/components/app';
|
||||||
|
import { INPUT_CLASSES } from '$lib/constants/input-classes';
|
||||||
|
import { config } from '$lib/stores/settings.svelte';
|
||||||
|
import ChatMessageActions from './ChatMessageActions.svelte';
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
class?: string;
|
||||||
|
message: DatabaseMessage;
|
||||||
|
isEditing: boolean;
|
||||||
|
editedContent: string;
|
||||||
|
siblingInfo?: ChatMessageSiblingInfo | null;
|
||||||
|
showDeleteDialog: boolean;
|
||||||
|
deletionInfo: {
|
||||||
|
totalCount: number;
|
||||||
|
userMessages: number;
|
||||||
|
assistantMessages: number;
|
||||||
|
messageTypes: string[];
|
||||||
|
} | null;
|
||||||
|
onCancelEdit: () => void;
|
||||||
|
onSaveEdit: () => void;
|
||||||
|
onEditKeydown: (event: KeyboardEvent) => void;
|
||||||
|
onEditedContentChange: (content: string) => void;
|
||||||
|
onCopy: () => void;
|
||||||
|
onEdit: () => void;
|
||||||
|
onDelete: () => void;
|
||||||
|
onConfirmDelete: () => void;
|
||||||
|
onNavigateToSibling?: (siblingId: string) => void;
|
||||||
|
onShowDeleteDialogChange: (show: boolean) => void;
|
||||||
|
textareaElement?: HTMLTextAreaElement;
|
||||||
|
}
|
||||||
|
|
||||||
|
let {
|
||||||
|
class: className = '',
|
||||||
|
message,
|
||||||
|
isEditing,
|
||||||
|
editedContent,
|
||||||
|
siblingInfo = null,
|
||||||
|
showDeleteDialog,
|
||||||
|
deletionInfo,
|
||||||
|
onCancelEdit,
|
||||||
|
onSaveEdit,
|
||||||
|
onEditKeydown,
|
||||||
|
onEditedContentChange,
|
||||||
|
onCopy,
|
||||||
|
onEdit,
|
||||||
|
onDelete,
|
||||||
|
onConfirmDelete,
|
||||||
|
onNavigateToSibling,
|
||||||
|
onShowDeleteDialogChange,
|
||||||
|
textareaElement = $bindable()
|
||||||
|
}: Props = $props();
|
||||||
|
|
||||||
|
let isMultiline = $state(false);
|
||||||
|
let messageElement: HTMLElement | undefined = $state();
|
||||||
|
let isExpanded = $state(false);
|
||||||
|
let contentHeight = $state(0);
|
||||||
|
const MAX_HEIGHT = 200; // pixels
|
||||||
|
const currentConfig = config();
|
||||||
|
|
||||||
|
let showExpandButton = $derived(contentHeight > MAX_HEIGHT);
|
||||||
|
|
||||||
|
$effect(() => {
|
||||||
|
if (!messageElement || !message.content.trim()) return;
|
||||||
|
|
||||||
|
if (message.content.includes('\n')) {
|
||||||
|
isMultiline = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
const resizeObserver = new ResizeObserver((entries) => {
|
||||||
|
for (const entry of entries) {
|
||||||
|
const element = entry.target as HTMLElement;
|
||||||
|
const estimatedSingleLineHeight = 24;
|
||||||
|
|
||||||
|
isMultiline = element.offsetHeight > estimatedSingleLineHeight * 1.5;
|
||||||
|
contentHeight = element.scrollHeight;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
resizeObserver.observe(messageElement);
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
resizeObserver.disconnect();
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
function toggleExpand() {
|
||||||
|
isExpanded = !isExpanded;
|
||||||
|
}
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<div
|
||||||
|
aria-label="System message with actions"
|
||||||
|
class="group flex flex-col items-end gap-3 md:gap-2 {className}"
|
||||||
|
role="group"
|
||||||
|
>
|
||||||
|
{#if isEditing}
|
||||||
|
<div class="w-full max-w-[80%]">
|
||||||
|
<textarea
|
||||||
|
bind:this={textareaElement}
|
||||||
|
bind:value={editedContent}
|
||||||
|
class="min-h-[60px] w-full resize-none rounded-2xl px-3 py-2 text-sm {INPUT_CLASSES}"
|
||||||
|
onkeydown={onEditKeydown}
|
||||||
|
oninput={(e) => onEditedContentChange(e.currentTarget.value)}
|
||||||
|
placeholder="Edit system message..."
|
||||||
|
></textarea>
|
||||||
|
|
||||||
|
<div class="mt-2 flex justify-end gap-2">
|
||||||
|
<Button class="h-8 px-3" onclick={onCancelEdit} size="sm" variant="outline">
|
||||||
|
<X class="mr-1 h-3 w-3" />
|
||||||
|
Cancel
|
||||||
|
</Button>
|
||||||
|
|
||||||
|
<Button class="h-8 px-3" onclick={onSaveEdit} disabled={!editedContent.trim()} size="sm">
|
||||||
|
<Check class="mr-1 h-3 w-3" />
|
||||||
|
Send
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{:else}
|
||||||
|
{#if message.content.trim()}
|
||||||
|
<div class="relative max-w-[80%]">
|
||||||
|
<button
|
||||||
|
class="group/expand w-full text-left {!isExpanded && showExpandButton
|
||||||
|
? 'cursor-pointer'
|
||||||
|
: 'cursor-auto'}"
|
||||||
|
onclick={showExpandButton && !isExpanded ? toggleExpand : undefined}
|
||||||
|
type="button"
|
||||||
|
>
|
||||||
|
<Card
|
||||||
|
class="rounded-[1.125rem] !border-2 !border-dashed !border-border/50 bg-muted px-3.75 py-1.5 data-[multiline]:py-2.5"
|
||||||
|
data-multiline={isMultiline ? '' : undefined}
|
||||||
|
style="border: 2px dashed hsl(var(--border));"
|
||||||
|
>
|
||||||
|
<div
|
||||||
|
class="relative overflow-hidden transition-all duration-300 {isExpanded
|
||||||
|
? 'cursor-text select-text'
|
||||||
|
: 'select-none'}"
|
||||||
|
style={!isExpanded && showExpandButton
|
||||||
|
? `max-height: ${MAX_HEIGHT}px;`
|
||||||
|
: 'max-height: none;'}
|
||||||
|
>
|
||||||
|
{#if currentConfig.renderUserContentAsMarkdown}
|
||||||
|
<div bind:this={messageElement} class="text-md {isExpanded ? 'cursor-text' : ''}">
|
||||||
|
<MarkdownContent class="markdown-system-content" content={message.content} />
|
||||||
|
</div>
|
||||||
|
{:else}
|
||||||
|
<span
|
||||||
|
bind:this={messageElement}
|
||||||
|
class="text-md whitespace-pre-wrap {isExpanded ? 'cursor-text' : ''}"
|
||||||
|
>
|
||||||
|
{message.content}
|
||||||
|
</span>
|
||||||
|
{/if}
|
||||||
|
|
||||||
|
{#if !isExpanded && showExpandButton}
|
||||||
|
<div
|
||||||
|
class="pointer-events-none absolute right-0 bottom-0 left-0 h-48 bg-gradient-to-t from-muted to-transparent"
|
||||||
|
></div>
|
||||||
|
<div
|
||||||
|
class="pointer-events-none absolute right-0 bottom-4 left-0 flex justify-center opacity-0 transition-opacity group-hover/expand:opacity-100"
|
||||||
|
>
|
||||||
|
<Button
|
||||||
|
class="rounded-full px-4 py-1.5 text-xs shadow-md"
|
||||||
|
size="sm"
|
||||||
|
variant="outline"
|
||||||
|
>
|
||||||
|
Show full system message
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{#if isExpanded && showExpandButton}
|
||||||
|
<div class="mb-2 flex justify-center">
|
||||||
|
<Button
|
||||||
|
class="rounded-full px-4 py-1.5 text-xs"
|
||||||
|
onclick={(e) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
toggleExpand();
|
||||||
|
}}
|
||||||
|
size="sm"
|
||||||
|
variant="outline"
|
||||||
|
>
|
||||||
|
Collapse System Message
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
|
</Card>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
|
|
||||||
|
{#if message.timestamp}
|
||||||
|
<div class="max-w-[80%]">
|
||||||
|
<ChatMessageActions
|
||||||
|
actionsPosition="right"
|
||||||
|
{deletionInfo}
|
||||||
|
justify="end"
|
||||||
|
{onConfirmDelete}
|
||||||
|
{onCopy}
|
||||||
|
{onDelete}
|
||||||
|
{onEdit}
|
||||||
|
{onNavigateToSibling}
|
||||||
|
{onShowDeleteDialogChange}
|
||||||
|
{siblingInfo}
|
||||||
|
{showDeleteDialog}
|
||||||
|
role="user"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
|
{/if}
|
||||||
|
</div>
|
||||||
|
|
@ -145,7 +145,7 @@
|
||||||
|
|
||||||
{#if message.content.trim()}
|
{#if message.content.trim()}
|
||||||
<Card
|
<Card
|
||||||
class="max-w-[80%] rounded-[1.125rem] bg-primary px-3.75 py-1.5 text-primary-foreground data-[multiline]:py-2.5"
|
class="max-w-[80%] rounded-[1.125rem] border-none bg-primary px-3.75 py-1.5 text-primary-foreground data-[multiline]:py-2.5"
|
||||||
data-multiline={isMultiline ? '' : undefined}
|
data-multiline={isMultiline ? '' : undefined}
|
||||||
>
|
>
|
||||||
{#if currentConfig.renderUserContentAsMarkdown}
|
{#if currentConfig.renderUserContentAsMarkdown}
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
import { ChatMessage } from '$lib/components/app';
|
import { ChatMessage } from '$lib/components/app';
|
||||||
import { chatStore } from '$lib/stores/chat.svelte';
|
import { chatStore } from '$lib/stores/chat.svelte';
|
||||||
import { conversationsStore, activeConversation } from '$lib/stores/conversations.svelte';
|
import { conversationsStore, activeConversation } from '$lib/stores/conversations.svelte';
|
||||||
|
import { config } from '$lib/stores/settings.svelte';
|
||||||
import { getMessageSiblings } from '$lib/utils';
|
import { getMessageSiblings } from '$lib/utils';
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
|
|
@ -13,6 +14,7 @@
|
||||||
let { class: className, messages = [], onUserAction }: Props = $props();
|
let { class: className, messages = [], onUserAction }: Props = $props();
|
||||||
|
|
||||||
let allConversationMessages = $state<DatabaseMessage[]>([]);
|
let allConversationMessages = $state<DatabaseMessage[]>([]);
|
||||||
|
const currentConfig = config();
|
||||||
|
|
||||||
function refreshAllMessages() {
|
function refreshAllMessages() {
|
||||||
const conversation = activeConversation();
|
const conversation = activeConversation();
|
||||||
|
|
@ -40,7 +42,12 @@
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
|
|
||||||
return messages.map((message) => {
|
// Filter out system messages if showSystemMessage is false
|
||||||
|
const filteredMessages = currentConfig.showSystemMessage
|
||||||
|
? messages
|
||||||
|
: messages.filter((msg) => msg.type !== 'system');
|
||||||
|
|
||||||
|
return filteredMessages.map((message) => {
|
||||||
const siblingInfo = getMessageSiblings(allConversationMessages, message.id);
|
const siblingInfo = getMessageSiblings(allConversationMessages, message.id);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
|
|
@ -36,12 +36,6 @@
|
||||||
title: 'General',
|
title: 'General',
|
||||||
icon: Settings,
|
icon: Settings,
|
||||||
fields: [
|
fields: [
|
||||||
{ key: 'apiKey', label: 'API Key', type: 'input' },
|
|
||||||
{
|
|
||||||
key: 'systemMessage',
|
|
||||||
label: 'System Message (will be disabled if left empty)',
|
|
||||||
type: 'textarea'
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
key: 'theme',
|
key: 'theme',
|
||||||
label: 'Theme',
|
label: 'Theme',
|
||||||
|
|
@ -52,6 +46,12 @@
|
||||||
{ value: 'dark', label: 'Dark', icon: Moon }
|
{ value: 'dark', label: 'Dark', icon: Moon }
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{ key: 'apiKey', label: 'API Key', type: 'input' },
|
||||||
|
{
|
||||||
|
key: 'systemMessage',
|
||||||
|
label: 'System Message',
|
||||||
|
type: 'textarea'
|
||||||
|
},
|
||||||
{
|
{
|
||||||
key: 'pasteLongTextToFileLen',
|
key: 'pasteLongTextToFileLen',
|
||||||
label: 'Paste long text to file length',
|
label: 'Paste long text to file length',
|
||||||
|
|
|
||||||
|
|
@ -95,7 +95,7 @@
|
||||||
</div>
|
</div>
|
||||||
{#if field.help || SETTING_CONFIG_INFO[field.key]}
|
{#if field.help || SETTING_CONFIG_INFO[field.key]}
|
||||||
<p class="mt-1 text-xs text-muted-foreground">
|
<p class="mt-1 text-xs text-muted-foreground">
|
||||||
{field.help || SETTING_CONFIG_INFO[field.key]}
|
{@html field.help || SETTING_CONFIG_INFO[field.key]}
|
||||||
</p>
|
</p>
|
||||||
{/if}
|
{/if}
|
||||||
{:else if field.type === 'textarea'}
|
{:else if field.type === 'textarea'}
|
||||||
|
|
@ -112,13 +112,28 @@
|
||||||
value={String(localConfig[field.key] ?? '')}
|
value={String(localConfig[field.key] ?? '')}
|
||||||
onchange={(e) => onConfigChange(field.key, e.currentTarget.value)}
|
onchange={(e) => onConfigChange(field.key, e.currentTarget.value)}
|
||||||
placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] ?? 'none'}`}
|
placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] ?? 'none'}`}
|
||||||
class="min-h-[100px] w-full md:max-w-2xl"
|
class="min-h-[10rem] w-full md:max-w-2xl"
|
||||||
/>
|
/>
|
||||||
|
|
||||||
{#if field.help || SETTING_CONFIG_INFO[field.key]}
|
{#if field.help || SETTING_CONFIG_INFO[field.key]}
|
||||||
<p class="mt-1 text-xs text-muted-foreground">
|
<p class="mt-1 text-xs text-muted-foreground">
|
||||||
{field.help || SETTING_CONFIG_INFO[field.key]}
|
{field.help || SETTING_CONFIG_INFO[field.key]}
|
||||||
</p>
|
</p>
|
||||||
{/if}
|
{/if}
|
||||||
|
|
||||||
|
{#if field.key === 'systemMessage'}
|
||||||
|
<div class="mt-3 flex items-center gap-2">
|
||||||
|
<Checkbox
|
||||||
|
id="showSystemMessage"
|
||||||
|
checked={Boolean(localConfig.showSystemMessage ?? true)}
|
||||||
|
onCheckedChange={(checked) => onConfigChange('showSystemMessage', Boolean(checked))}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<Label for="showSystemMessage" class="cursor-pointer text-sm font-normal">
|
||||||
|
Show system message in conversations
|
||||||
|
</Label>
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
{:else if field.type === 'select'}
|
{:else if field.type === 'select'}
|
||||||
{@const selectedOption = field.options?.find(
|
{@const selectedOption = field.options?.find(
|
||||||
(opt: { value: string; label: string; icon?: Component }) =>
|
(opt: { value: string; label: string; icon?: Component }) =>
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@
|
||||||
import * as AlertDialog from '$lib/components/ui/alert-dialog';
|
import * as AlertDialog from '$lib/components/ui/alert-dialog';
|
||||||
import Input from '$lib/components/ui/input/input.svelte';
|
import Input from '$lib/components/ui/input/input.svelte';
|
||||||
import { conversationsStore, conversations } from '$lib/stores/conversations.svelte';
|
import { conversationsStore, conversations } from '$lib/stores/conversations.svelte';
|
||||||
|
import { chatStore } from '$lib/stores/chat.svelte';
|
||||||
import ChatSidebarActions from './ChatSidebarActions.svelte';
|
import ChatSidebarActions from './ChatSidebarActions.svelte';
|
||||||
|
|
||||||
const sidebar = Sidebar.useSidebar();
|
const sidebar = Sidebar.useSidebar();
|
||||||
|
|
@ -98,6 +99,10 @@
|
||||||
|
|
||||||
await goto(`#/chat/${id}`);
|
await goto(`#/chat/${id}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function handleStopGeneration(id: string) {
|
||||||
|
chatStore.stopGenerationForChat(id);
|
||||||
|
}
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<ScrollArea class="h-[100vh]">
|
<ScrollArea class="h-[100vh]">
|
||||||
|
|
@ -132,6 +137,7 @@
|
||||||
onSelect={selectConversation}
|
onSelect={selectConversation}
|
||||||
onEdit={handleEditConversation}
|
onEdit={handleEditConversation}
|
||||||
onDelete={handleDeleteConversation}
|
onDelete={handleDeleteConversation}
|
||||||
|
onStop={handleStopGeneration}
|
||||||
/>
|
/>
|
||||||
</Sidebar.MenuItem>
|
</Sidebar.MenuItem>
|
||||||
{/each}
|
{/each}
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
<script lang="ts">
|
<script lang="ts">
|
||||||
import { Trash2, Pencil, MoreHorizontal, Download, Loader2 } from '@lucide/svelte';
|
import { Trash2, Pencil, MoreHorizontal, Download, Loader2, Square } from '@lucide/svelte';
|
||||||
import { ActionDropdown } from '$lib/components/app';
|
import { ActionDropdown } from '$lib/components/app';
|
||||||
|
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||||
import { getAllLoadingChats } from '$lib/stores/chat.svelte';
|
import { getAllLoadingChats } from '$lib/stores/chat.svelte';
|
||||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||||
import { onMount } from 'svelte';
|
import { onMount } from 'svelte';
|
||||||
|
|
@ -12,6 +13,7 @@
|
||||||
onDelete?: (id: string) => void;
|
onDelete?: (id: string) => void;
|
||||||
onEdit?: (id: string) => void;
|
onEdit?: (id: string) => void;
|
||||||
onSelect?: (id: string) => void;
|
onSelect?: (id: string) => void;
|
||||||
|
onStop?: (id: string) => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
let {
|
let {
|
||||||
|
|
@ -20,6 +22,7 @@
|
||||||
onDelete,
|
onDelete,
|
||||||
onEdit,
|
onEdit,
|
||||||
onSelect,
|
onSelect,
|
||||||
|
onStop,
|
||||||
isActive = false
|
isActive = false
|
||||||
}: Props = $props();
|
}: Props = $props();
|
||||||
|
|
||||||
|
|
@ -38,8 +41,14 @@
|
||||||
onDelete?.(conversation.id);
|
onDelete?.(conversation.id);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function handleStop(event: Event) {
|
||||||
|
event.stopPropagation();
|
||||||
|
onStop?.(conversation.id);
|
||||||
|
}
|
||||||
|
|
||||||
function handleGlobalEditEvent(event: Event) {
|
function handleGlobalEditEvent(event: Event) {
|
||||||
const customEvent = event as CustomEvent<{ conversationId: string }>;
|
const customEvent = event as CustomEvent<{ conversationId: string }>;
|
||||||
|
|
||||||
if (customEvent.detail.conversationId === conversation.id && isActive) {
|
if (customEvent.detail.conversationId === conversation.id && isActive) {
|
||||||
handleEdit(event);
|
handleEdit(event);
|
||||||
}
|
}
|
||||||
|
|
@ -88,8 +97,28 @@
|
||||||
>
|
>
|
||||||
<div class="flex min-w-0 flex-1 items-center gap-2">
|
<div class="flex min-w-0 flex-1 items-center gap-2">
|
||||||
{#if isLoading}
|
{#if isLoading}
|
||||||
<Loader2 class="h-3.5 w-3.5 shrink-0 animate-spin text-muted-foreground" />
|
<Tooltip.Root>
|
||||||
|
<Tooltip.Trigger>
|
||||||
|
<div
|
||||||
|
class="stop-button flex h-4 w-4 shrink-0 cursor-pointer items-center justify-center rounded text-muted-foreground transition-colors hover:text-foreground"
|
||||||
|
onclick={handleStop}
|
||||||
|
onkeydown={(e) => e.key === 'Enter' && handleStop(e)}
|
||||||
|
role="button"
|
||||||
|
tabindex="0"
|
||||||
|
aria-label="Stop generation"
|
||||||
|
>
|
||||||
|
<Loader2 class="loading-icon h-3.5 w-3.5 animate-spin" />
|
||||||
|
|
||||||
|
<Square class="stop-icon hidden h-3 w-3 fill-current text-destructive" />
|
||||||
|
</div>
|
||||||
|
</Tooltip.Trigger>
|
||||||
|
|
||||||
|
<Tooltip.Content>
|
||||||
|
<p>Stop generation</p>
|
||||||
|
</Tooltip.Content>
|
||||||
|
</Tooltip.Root>
|
||||||
{/if}
|
{/if}
|
||||||
|
|
||||||
<!-- svelte-ignore a11y_click_events_have_key_events -->
|
<!-- svelte-ignore a11y_click_events_have_key_events -->
|
||||||
<!-- svelte-ignore a11y_no_static_element_interactions -->
|
<!-- svelte-ignore a11y_no_static_element_interactions -->
|
||||||
<span class="truncate text-sm font-medium" onclick={handleMobileSidebarItemClick}>
|
<span class="truncate text-sm font-medium" onclick={handleMobileSidebarItemClick}>
|
||||||
|
|
@ -147,5 +176,25 @@
|
||||||
opacity: 1 !important;
|
opacity: 1 !important;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.stop-button {
|
||||||
|
:global(.stop-icon) {
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
:global(.loading-icon) {
|
||||||
|
display: block;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
&:is(:hover) .stop-button {
|
||||||
|
:global(.stop-icon) {
|
||||||
|
display: block;
|
||||||
|
}
|
||||||
|
|
||||||
|
:global(.loading-icon) {
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
</style>
|
</style>
|
||||||
|
|
|
||||||
|
|
@ -19,8 +19,10 @@ export { default as ChatMessage } from './chat/ChatMessages/ChatMessage.svelte';
|
||||||
export { default as ChatMessageActions } from './chat/ChatMessages/ChatMessageActions.svelte';
|
export { default as ChatMessageActions } from './chat/ChatMessages/ChatMessageActions.svelte';
|
||||||
export { default as ChatMessageBranchingControls } from './chat/ChatMessages/ChatMessageBranchingControls.svelte';
|
export { default as ChatMessageBranchingControls } from './chat/ChatMessages/ChatMessageBranchingControls.svelte';
|
||||||
export { default as ChatMessageStatistics } from './chat/ChatMessages/ChatMessageStatistics.svelte';
|
export { default as ChatMessageStatistics } from './chat/ChatMessages/ChatMessageStatistics.svelte';
|
||||||
|
export { default as ChatMessageSystem } from './chat/ChatMessages/ChatMessageSystem.svelte';
|
||||||
export { default as ChatMessageThinkingBlock } from './chat/ChatMessages/ChatMessageThinkingBlock.svelte';
|
export { default as ChatMessageThinkingBlock } from './chat/ChatMessages/ChatMessageThinkingBlock.svelte';
|
||||||
export { default as ChatMessages } from './chat/ChatMessages/ChatMessages.svelte';
|
export { default as ChatMessages } from './chat/ChatMessages/ChatMessages.svelte';
|
||||||
|
export { default as MessageBranchingControls } from './chat/ChatMessages/ChatMessageBranchingControls.svelte';
|
||||||
|
|
||||||
export { default as ChatScreen } from './chat/ChatScreen/ChatScreen.svelte';
|
export { default as ChatScreen } from './chat/ChatScreen/ChatScreen.svelte';
|
||||||
export { default as ChatScreenHeader } from './chat/ChatScreen/ChatScreenHeader.svelte';
|
export { default as ChatScreenHeader } from './chat/ChatScreen/ChatScreenHeader.svelte';
|
||||||
|
|
|
||||||
|
|
@ -337,19 +337,23 @@
|
||||||
line-height: 1.75;
|
line-height: 1.75;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
div :global(:is(h1, h2, h3, h4, h5, h6):first-child) {
|
||||||
|
margin-top: 0;
|
||||||
|
}
|
||||||
|
|
||||||
/* Headers with consistent spacing */
|
/* Headers with consistent spacing */
|
||||||
div :global(h1) {
|
div :global(h1) {
|
||||||
font-size: 1.875rem;
|
font-size: 1.875rem;
|
||||||
font-weight: 700;
|
font-weight: 700;
|
||||||
margin: 1.5rem 0 0.75rem 0;
|
|
||||||
line-height: 1.2;
|
line-height: 1.2;
|
||||||
|
margin: 1.5rem 0 0.75rem 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
div :global(h2) {
|
div :global(h2) {
|
||||||
font-size: 1.5rem;
|
font-size: 1.5rem;
|
||||||
font-weight: 600;
|
font-weight: 600;
|
||||||
margin: 1.25rem 0 0.5rem 0;
|
|
||||||
line-height: 1.3;
|
line-height: 1.3;
|
||||||
|
margin: 1.25rem 0 0.5rem 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
div :global(h3) {
|
div :global(h3) {
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> =
|
||||||
// Do not use nested objects, keep it single level. Prefix the key if you need to group them.
|
// Do not use nested objects, keep it single level. Prefix the key if you need to group them.
|
||||||
apiKey: '',
|
apiKey: '',
|
||||||
systemMessage: '',
|
systemMessage: '',
|
||||||
|
showSystemMessage: true,
|
||||||
theme: 'system',
|
theme: 'system',
|
||||||
showThoughtInProgress: false,
|
showThoughtInProgress: false,
|
||||||
showToolCalls: false,
|
showToolCalls: false,
|
||||||
|
|
@ -43,8 +44,9 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> =
|
||||||
};
|
};
|
||||||
|
|
||||||
export const SETTING_CONFIG_INFO: Record<string, string> = {
|
export const SETTING_CONFIG_INFO: Record<string, string> = {
|
||||||
apiKey: 'Set the API Key if you are using --api-key option for the server.',
|
apiKey: 'Set the API Key if you are using <code>--api-key</code> option for the server.',
|
||||||
systemMessage: 'The starting message that defines how model should behave.',
|
systemMessage: 'The starting message that defines how model should behave.',
|
||||||
|
showSystemMessage: 'Display the system message at the top of each conversation.',
|
||||||
theme:
|
theme:
|
||||||
'Choose the color theme for the interface. You can choose between System (follows your device settings), Light, or Dark.',
|
'Choose the color theme for the interface. You can choose between System (follows your device settings), Light, or Dark.',
|
||||||
pasteLongTextToFileLen:
|
pasteLongTextToFileLen:
|
||||||
|
|
|
||||||
|
|
@ -90,7 +90,6 @@ export class ChatService {
|
||||||
custom,
|
custom,
|
||||||
timings_per_token,
|
timings_per_token,
|
||||||
// Config options
|
// Config options
|
||||||
systemMessage,
|
|
||||||
disableReasoningFormat
|
disableReasoningFormat
|
||||||
} = options;
|
} = options;
|
||||||
|
|
||||||
|
|
@ -104,6 +103,7 @@ export class ChatService {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.filter((msg) => {
|
.filter((msg) => {
|
||||||
|
// Filter out empty system messages
|
||||||
if (msg.role === 'system') {
|
if (msg.role === 'system') {
|
||||||
const content = typeof msg.content === 'string' ? msg.content : '';
|
const content = typeof msg.content === 'string' ? msg.content : '';
|
||||||
|
|
||||||
|
|
@ -113,10 +113,8 @@ export class ChatService {
|
||||||
return true;
|
return true;
|
||||||
});
|
});
|
||||||
|
|
||||||
const processedMessages = ChatService.injectSystemMessage(normalizedMessages, systemMessage);
|
|
||||||
|
|
||||||
const requestBody: ApiChatCompletionRequest = {
|
const requestBody: ApiChatCompletionRequest = {
|
||||||
messages: processedMessages.map((msg: ApiChatMessageData) => ({
|
messages: normalizedMessages.map((msg: ApiChatMessageData) => ({
|
||||||
role: msg.role,
|
role: msg.role,
|
||||||
content: msg.content
|
content: msg.content
|
||||||
})),
|
})),
|
||||||
|
|
@ -680,46 +678,6 @@ export class ChatService {
|
||||||
// Utilities
|
// Utilities
|
||||||
// ─────────────────────────────────────────────────────────────────────────────
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
/**
|
|
||||||
* Injects a system message at the beginning of the conversation if provided.
|
|
||||||
* Checks for existing system messages to avoid duplication.
|
|
||||||
*
|
|
||||||
* @param messages - Array of chat messages to process
|
|
||||||
* @param systemMessage - Optional system message to inject
|
|
||||||
* @returns Array of messages with system message injected at the beginning if provided
|
|
||||||
* @private
|
|
||||||
*/
|
|
||||||
private static injectSystemMessage(
|
|
||||||
messages: ApiChatMessageData[],
|
|
||||||
systemMessage?: string
|
|
||||||
): ApiChatMessageData[] {
|
|
||||||
const trimmedSystemMessage = systemMessage?.trim();
|
|
||||||
|
|
||||||
if (!trimmedSystemMessage) {
|
|
||||||
return messages;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (messages.length > 0 && messages[0].role === 'system') {
|
|
||||||
if (messages[0].content !== trimmedSystemMessage) {
|
|
||||||
const updatedMessages = [...messages];
|
|
||||||
updatedMessages[0] = {
|
|
||||||
role: 'system',
|
|
||||||
content: trimmedSystemMessage
|
|
||||||
};
|
|
||||||
return updatedMessages;
|
|
||||||
}
|
|
||||||
|
|
||||||
return messages;
|
|
||||||
}
|
|
||||||
|
|
||||||
const systemMsg: ApiChatMessageData = {
|
|
||||||
role: 'system',
|
|
||||||
content: trimmedSystemMessage
|
|
||||||
};
|
|
||||||
|
|
||||||
return [systemMsg, ...messages];
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Parses error response and creates appropriate error with context information
|
* Parses error response and creates appropriate error with context information
|
||||||
* @param response - HTTP response object
|
* @param response - HTTP response object
|
||||||
|
|
|
||||||
|
|
@ -166,6 +166,49 @@ export class DatabaseService {
|
||||||
return rootMessage.id;
|
return rootMessage.id;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a system prompt message for a conversation.
|
||||||
|
*
|
||||||
|
* @param convId - Conversation ID
|
||||||
|
* @param systemPrompt - The system prompt content (must be non-empty)
|
||||||
|
* @param parentId - Parent message ID (typically the root message)
|
||||||
|
* @returns The created system message
|
||||||
|
* @throws Error if systemPrompt is empty
|
||||||
|
*/
|
||||||
|
static async createSystemMessage(
|
||||||
|
convId: string,
|
||||||
|
systemPrompt: string,
|
||||||
|
parentId: string
|
||||||
|
): Promise<DatabaseMessage> {
|
||||||
|
const trimmedPrompt = systemPrompt.trim();
|
||||||
|
if (!trimmedPrompt) {
|
||||||
|
throw new Error('Cannot create system message with empty content');
|
||||||
|
}
|
||||||
|
|
||||||
|
const systemMessage: DatabaseMessage = {
|
||||||
|
id: uuid(),
|
||||||
|
convId,
|
||||||
|
type: 'system',
|
||||||
|
timestamp: Date.now(),
|
||||||
|
role: 'system',
|
||||||
|
content: trimmedPrompt,
|
||||||
|
parent: parentId,
|
||||||
|
thinking: '',
|
||||||
|
children: []
|
||||||
|
};
|
||||||
|
|
||||||
|
await db.messages.add(systemMessage);
|
||||||
|
|
||||||
|
const parentMessage = await db.messages.get(parentId);
|
||||||
|
if (parentMessage) {
|
||||||
|
await db.messages.update(parentId, {
|
||||||
|
children: [...parentMessage.children, systemMessage.id]
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
return systemMessage;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Deletes a conversation and all its messages.
|
* Deletes a conversation and all its messages.
|
||||||
*
|
*
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,11 @@ import { DatabaseService, ChatService } from '$lib/services';
|
||||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||||
import { config } from '$lib/stores/settings.svelte';
|
import { config } from '$lib/stores/settings.svelte';
|
||||||
import { contextSize, isRouterMode } from '$lib/stores/server.svelte';
|
import { contextSize, isRouterMode } from '$lib/stores/server.svelte';
|
||||||
import { selectedModelName, modelsStore } from '$lib/stores/models.svelte';
|
import {
|
||||||
|
selectedModelName,
|
||||||
|
modelsStore,
|
||||||
|
selectedModelContextSize
|
||||||
|
} from '$lib/stores/models.svelte';
|
||||||
import {
|
import {
|
||||||
normalizeModelName,
|
normalizeModelName,
|
||||||
filterByLeafNodeId,
|
filterByLeafNodeId,
|
||||||
|
|
@ -261,6 +265,13 @@ class ChatStore {
|
||||||
return activeState.contextTotal;
|
return activeState.contextTotal;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (isRouterMode()) {
|
||||||
|
const modelContextSize = selectedModelContextSize();
|
||||||
|
if (modelContextSize && modelContextSize > 0) {
|
||||||
|
return modelContextSize;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const propsContextSize = contextSize();
|
const propsContextSize = contextSize();
|
||||||
if (propsContextSize && propsContextSize > 0) {
|
if (propsContextSize && propsContextSize > 0) {
|
||||||
return propsContextSize;
|
return propsContextSize;
|
||||||
|
|
@ -458,6 +469,14 @@ class ChatStore {
|
||||||
onError?: (error: Error) => void,
|
onError?: (error: Error) => void,
|
||||||
modelOverride?: string | null
|
modelOverride?: string | null
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
|
// Ensure model props are cached before streaming (for correct n_ctx in processing info)
|
||||||
|
if (isRouterMode()) {
|
||||||
|
const modelName = modelOverride || selectedModelName();
|
||||||
|
if (modelName && !modelsStore.getModelProps(modelName)) {
|
||||||
|
await modelsStore.fetchModelProps(modelName);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let streamedContent = '';
|
let streamedContent = '';
|
||||||
let streamedReasoningContent = '';
|
let streamedReasoningContent = '';
|
||||||
let streamedToolCallContent = '';
|
let streamedToolCallContent = '';
|
||||||
|
|
@ -624,6 +643,22 @@ class ChatStore {
|
||||||
this.clearChatStreaming(currentConv.id);
|
this.clearChatStreaming(currentConv.id);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
if (isNewConversation) {
|
||||||
|
const rootId = await DatabaseService.createRootMessage(currentConv.id);
|
||||||
|
const currentConfig = config();
|
||||||
|
const systemPrompt = currentConfig.systemMessage?.toString().trim();
|
||||||
|
|
||||||
|
if (systemPrompt) {
|
||||||
|
const systemMessage = await DatabaseService.createSystemMessage(
|
||||||
|
currentConv.id,
|
||||||
|
systemPrompt,
|
||||||
|
rootId
|
||||||
|
);
|
||||||
|
|
||||||
|
conversationsStore.addMessageToActive(systemMessage);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const userMessage = await this.addMessage('user', content, 'text', '-1', extras);
|
const userMessage = await this.addMessage('user', content, 'text', '-1', extras);
|
||||||
if (!userMessage) throw new Error('Failed to add user message');
|
if (!userMessage) throw new Error('Failed to add user message');
|
||||||
if (isNewConversation && content)
|
if (isNewConversation && content)
|
||||||
|
|
@ -666,13 +701,17 @@ class ChatStore {
|
||||||
|
|
||||||
if (!activeConv) return;
|
if (!activeConv) return;
|
||||||
|
|
||||||
await this.savePartialResponseIfNeeded(activeConv.id);
|
await this.stopGenerationForChat(activeConv.id);
|
||||||
|
}
|
||||||
|
|
||||||
|
async stopGenerationForChat(convId: string): Promise<void> {
|
||||||
|
await this.savePartialResponseIfNeeded(convId);
|
||||||
|
|
||||||
this.stopStreaming();
|
this.stopStreaming();
|
||||||
this.abortRequest(activeConv.id);
|
this.abortRequest(convId);
|
||||||
this.setChatLoading(activeConv.id, false);
|
this.setChatLoading(convId, false);
|
||||||
this.clearChatStreaming(activeConv.id);
|
this.clearChatStreaming(convId);
|
||||||
this.clearProcessingState(activeConv.id);
|
this.clearProcessingState(convId);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -999,14 +1038,20 @@ class ChatStore {
|
||||||
const activeConv = conversationsStore.activeConversation;
|
const activeConv = conversationsStore.activeConversation;
|
||||||
if (!activeConv || this.isLoading) return;
|
if (!activeConv || this.isLoading) return;
|
||||||
|
|
||||||
const result = this.getMessageByIdWithRole(messageId, 'user');
|
let result = this.getMessageByIdWithRole(messageId, 'user');
|
||||||
|
|
||||||
|
if (!result) {
|
||||||
|
result = this.getMessageByIdWithRole(messageId, 'system');
|
||||||
|
}
|
||||||
|
|
||||||
if (!result) return;
|
if (!result) return;
|
||||||
const { message: msg } = result;
|
const { message: msg } = result;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
|
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
|
||||||
const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null);
|
const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null);
|
||||||
const isFirstUserMessage = rootMessage && msg.parent === rootMessage.id;
|
const isFirstUserMessage =
|
||||||
|
msg.role === 'user' && rootMessage && msg.parent === rootMessage.id;
|
||||||
|
|
||||||
const parentId = msg.parent || rootMessage?.id;
|
const parentId = msg.parent || rootMessage?.id;
|
||||||
if (!parentId) return;
|
if (!parentId) return;
|
||||||
|
|
@ -1037,7 +1082,10 @@ class ChatStore {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
await conversationsStore.refreshActiveMessages();
|
await conversationsStore.refreshActiveMessages();
|
||||||
|
|
||||||
|
if (msg.role === 'user') {
|
||||||
await this.generateResponseForMessage(newMessage.id);
|
await this.generateResponseForMessage(newMessage.id);
|
||||||
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Failed to edit message with branching:', error);
|
console.error('Failed to edit message with branching:', error);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -158,6 +158,22 @@ class ModelsStore {
|
||||||
return this.modelPropsCache.get(modelId) ?? null;
|
return this.modelPropsCache.get(modelId) ?? null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get context size (n_ctx) for a specific model from cached props
|
||||||
|
*/
|
||||||
|
getModelContextSize(modelId: string): number | null {
|
||||||
|
const props = this.modelPropsCache.get(modelId);
|
||||||
|
return props?.default_generation_settings?.n_ctx ?? null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get context size for the currently selected model or null if no model is selected
|
||||||
|
*/
|
||||||
|
get selectedModelContextSize(): number | null {
|
||||||
|
if (!this.selectedModelName) return null;
|
||||||
|
return this.getModelContextSize(this.selectedModelName);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Check if props are being fetched for a model
|
* Check if props are being fetched for a model
|
||||||
*/
|
*/
|
||||||
|
|
@ -579,3 +595,4 @@ export const loadedModelIds = () => modelsStore.loadedModelIds;
|
||||||
export const loadingModelIds = () => modelsStore.loadingModelIds;
|
export const loadingModelIds = () => modelsStore.loadingModelIds;
|
||||||
export const propsCacheVersion = () => modelsStore.propsCacheVersion;
|
export const propsCacheVersion = () => modelsStore.propsCacheVersion;
|
||||||
export const singleModelName = () => modelsStore.singleModelName;
|
export const singleModelName = () => modelsStore.singleModelName;
|
||||||
|
export const selectedModelContextSize = () => modelsStore.selectedModelContextSize;
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
export type ChatMessageType = 'root' | 'text' | 'think';
|
export type ChatMessageType = 'root' | 'text' | 'think' | 'system';
|
||||||
export type ChatRole = 'user' | 'assistant' | 'system';
|
export type ChatRole = 'user' | 'assistant' | 'system';
|
||||||
|
|
||||||
export interface ChatUploadedFile {
|
export interface ChatUploadedFile {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue