From 3c3635d2f20424d557b5b0605a2a356214ffe048 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Sat, 6 Sep 2025 19:45:24 +0700 Subject: [PATCH 01/46] server : speed up tests (#15836) * server : speed up tests * clean up * restore timeout_seconds in some places * flake8 * explicit offline --- scripts/tool_bench.py | 4 +- tools/server/tests/unit/test_basic.py | 6 ++ tools/server/tests/unit/test_template.py | 9 +-- tools/server/tests/unit/test_tool_call.py | 20 +++--- tools/server/tests/unit/test_vision_api.py | 77 +++++++++++++--------- tools/server/tests/utils.py | 24 ++++++- 6 files changed, 90 insertions(+), 50 deletions(-) diff --git a/scripts/tool_bench.py b/scripts/tool_bench.py index 05d6dfc30a..e1512a49fd 100755 --- a/scripts/tool_bench.py +++ b/scripts/tool_bench.py @@ -53,7 +53,7 @@ import typer sys.path.insert(0, Path(__file__).parent.parent.as_posix()) if True: from tools.server.tests.utils import ServerProcess - from tools.server.tests.unit.test_tool_call import TIMEOUT_SERVER_START, do_test_calc_result, do_test_hello_world, do_test_weather + from tools.server.tests.unit.test_tool_call import do_test_calc_result, do_test_hello_world, do_test_weather @contextmanager @@ -335,7 +335,7 @@ def run( # server.debug = True with scoped_server(server): - server.start(timeout_seconds=TIMEOUT_SERVER_START) + server.start(timeout_seconds=15 * 60) for ignore_chat_grammar in [False]: run( server, diff --git a/tools/server/tests/unit/test_basic.py b/tools/server/tests/unit/test_basic.py index c7b3af0489..58ade52be6 100644 --- a/tools/server/tests/unit/test_basic.py +++ b/tools/server/tests/unit/test_basic.py @@ -5,6 +5,12 @@ from utils import * server = ServerPreset.tinyllama2() +@pytest.fixture(scope="session", autouse=True) +def do_something(): + # this will be run once per test session, before any tests + ServerPreset.load_all() + + @pytest.fixture(autouse=True) def create_server(): global server diff --git a/tools/server/tests/unit/test_template.py b/tools/server/tests/unit/test_template.py index c53eda5b88..e5185fcbfa 100644 --- a/tools/server/tests/unit/test_template.py +++ b/tools/server/tests/unit/test_template.py @@ -14,14 +14,11 @@ from utils import * server: ServerProcess -TIMEOUT_SERVER_START = 15*60 - @pytest.fixture(autouse=True) def create_server(): global server server = ServerPreset.tinyllama2() server.model_alias = "tinyllama-2" - server.server_port = 8081 server.n_slots = 1 @@ -45,7 +42,7 @@ def test_reasoning_budget(template_name: str, reasoning_budget: int | None, expe server.jinja = True server.reasoning_budget = reasoning_budget server.chat_template_file = f'../../../models/templates/{template_name}.jinja' - server.start(timeout_seconds=TIMEOUT_SERVER_START) + server.start() res = server.make_request("POST", "/apply-template", data={ "messages": [ @@ -68,7 +65,7 @@ def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]): global server server.jinja = True server.chat_template_file = f'../../../models/templates/{template_name}.jinja' - server.start(timeout_seconds=TIMEOUT_SERVER_START) + server.start() res = server.make_request("POST", "/apply-template", data={ "messages": [ @@ -91,7 +88,7 @@ def test_add_generation_prompt(template_name: str, expected_generation_prompt: s global server server.jinja = True server.chat_template_file = f'../../../models/templates/{template_name}.jinja' - server.start(timeout_seconds=TIMEOUT_SERVER_START) + server.start() res = server.make_request("POST", "/apply-template", data={ "messages": [ diff --git a/tools/server/tests/unit/test_tool_call.py b/tools/server/tests/unit/test_tool_call.py index a3c3ccdf58..b8f0f10863 100755 --- a/tools/server/tests/unit/test_tool_call.py +++ b/tools/server/tests/unit/test_tool_call.py @@ -12,7 +12,7 @@ from enum import Enum server: ServerProcess -TIMEOUT_SERVER_START = 15*60 +TIMEOUT_START_SLOW = 15 * 60 # this is needed for real model tests TIMEOUT_HTTP_REQUEST = 60 @pytest.fixture(autouse=True) @@ -124,7 +124,7 @@ def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, server.jinja = True server.n_predict = n_predict server.chat_template_file = f'../../../models/templates/{template_name}.jinja' - server.start(timeout_seconds=TIMEOUT_SERVER_START) + server.start() do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED, temperature=0.0, top_k=1, top_p=1.0) @@ -168,7 +168,7 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, server.jinja = True server.n_predict = n_predict server.chat_template_file = f'../../../models/templates/{template_name}.jinja' - server.start(timeout_seconds=TIMEOUT_SERVER_START) + server.start(timeout_seconds=TIMEOUT_START_SLOW) do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED) @@ -240,7 +240,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." elif isinstance(template_override, str): server.chat_template = template_override - server.start(timeout_seconds=TIMEOUT_SERVER_START) + server.start(timeout_seconds=TIMEOUT_START_SLOW) body = server.make_any_request("POST", "/v1/chat/completions", data={ "max_tokens": n_predict, "messages": [ @@ -295,7 +295,7 @@ def test_completion_without_tool_call_fast(template_name: str, n_predict: int, t server.n_predict = n_predict server.jinja = True server.chat_template_file = f'../../../models/templates/{template_name}.jinja' - server.start(timeout_seconds=TIMEOUT_SERVER_START) + server.start() do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED) @@ -317,7 +317,7 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t server.n_predict = n_predict server.jinja = True server.chat_template_file = f'../../../models/templates/{template_name}.jinja' - server.start(timeout_seconds=TIMEOUT_SERVER_START) + server.start(timeout_seconds=TIMEOUT_START_SLOW) do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED) @@ -377,7 +377,7 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." elif isinstance(template_override, str): server.chat_template = template_override - server.start(timeout_seconds=TIMEOUT_SERVER_START) + server.start() do_test_weather(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict) @@ -436,7 +436,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." elif isinstance(template_override, str): server.chat_template = template_override - server.start(timeout_seconds=TIMEOUT_SERVER_START) + server.start(timeout_seconds=TIMEOUT_START_SLOW) do_test_calc_result(server, result_override, n_predict, stream=stream == CompletionMode.STREAMED) @@ -524,7 +524,7 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." elif isinstance(template_override, str): server.chat_template = template_override - server.start(timeout_seconds=TIMEOUT_SERVER_START) + server.start() body = server.make_any_request("POST", "/v1/chat/completions", data={ "max_tokens": n_predict, "messages": [ @@ -597,7 +597,7 @@ def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | Non assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." elif isinstance(template_override, str): server.chat_template = template_override - server.start(timeout_seconds=TIMEOUT_SERVER_START) + server.start(timeout_seconds=TIMEOUT_START_SLOW) do_test_hello_world(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict) diff --git a/tools/server/tests/unit/test_vision_api.py b/tools/server/tests/unit/test_vision_api.py index 36d14b3885..9408116d1c 100644 --- a/tools/server/tests/unit/test_vision_api.py +++ b/tools/server/tests/unit/test_vision_api.py @@ -5,18 +5,31 @@ import requests server: ServerProcess -IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png" -IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png" - -response = requests.get(IMG_URL_0) -response.raise_for_status() # Raise an exception for bad status codes -IMG_BASE64_URI_0 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8") -IMG_BASE64_0 = base64.b64encode(response.content).decode("utf-8") - -response = requests.get(IMG_URL_1) -response.raise_for_status() # Raise an exception for bad status codes -IMG_BASE64_URI_1 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8") -IMG_BASE64_1 = base64.b64encode(response.content).decode("utf-8") +def get_img_url(id: str) -> str: + IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png" + IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png" + if id == "IMG_URL_0": + return IMG_URL_0 + elif id == "IMG_URL_1": + return IMG_URL_1 + elif id == "IMG_BASE64_URI_0": + response = requests.get(IMG_URL_0) + response.raise_for_status() # Raise an exception for bad status codes + return "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8") + elif id == "IMG_BASE64_0": + response = requests.get(IMG_URL_0) + response.raise_for_status() # Raise an exception for bad status codes + return base64.b64encode(response.content).decode("utf-8") + elif id == "IMG_BASE64_URI_1": + response = requests.get(IMG_URL_1) + response.raise_for_status() # Raise an exception for bad status codes + return "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8") + elif id == "IMG_BASE64_1": + response = requests.get(IMG_URL_1) + response.raise_for_status() # Raise an exception for bad status codes + return base64.b64encode(response.content).decode("utf-8") + else: + return id JSON_MULTIMODAL_KEY = "multimodal_data" JSON_PROMPT_STRING_KEY = "prompt_string" @@ -28,7 +41,7 @@ def create_server(): def test_models_supports_multimodal_capability(): global server - server.start() # vision model may take longer to load due to download size + server.start() res = server.make_request("GET", "/models", data={}) assert res.status_code == 200 model_info = res.body["models"][0] @@ -38,7 +51,7 @@ def test_models_supports_multimodal_capability(): def test_v1_models_supports_multimodal_capability(): global server - server.start() # vision model may take longer to load due to download size + server.start() res = server.make_request("GET", "/v1/models", data={}) assert res.status_code == 200 model_info = res.body["models"][0] @@ -50,10 +63,10 @@ def test_v1_models_supports_multimodal_capability(): "prompt, image_url, success, re_content", [ # test model is trained on CIFAR-10, but it's quite dumb due to small size - ("What is this:\n", IMG_URL_0, True, "(cat)+"), - ("What is this:\n", "IMG_BASE64_URI_0", True, "(cat)+"), # exceptional, so that we don't cog up the log - ("What is this:\n", IMG_URL_1, True, "(frog)+"), - ("Test test\n", IMG_URL_1, True, "(frog)+"), # test invalidate cache + ("What is this:\n", "IMG_URL_0", True, "(cat)+"), + ("What is this:\n", "IMG_BASE64_URI_0", True, "(cat)+"), + ("What is this:\n", "IMG_URL_1", True, "(frog)+"), + ("Test test\n", "IMG_URL_1", True, "(frog)+"), # test invalidate cache ("What is this:\n", "malformed", False, None), ("What is this:\n", "https://google.com/404", False, None), # non-existent image ("What is this:\n", "https://ggml.ai", False, None), # non-image data @@ -62,9 +75,7 @@ def test_v1_models_supports_multimodal_capability(): ) def test_vision_chat_completion(prompt, image_url, success, re_content): global server - server.start(timeout_seconds=60) # vision model may take longer to load due to download size - if image_url == "IMG_BASE64_URI_0": - image_url = IMG_BASE64_URI_0 + server.start() res = server.make_request("POST", "/chat/completions", data={ "temperature": 0.0, "top_k": 1, @@ -72,7 +83,7 @@ def test_vision_chat_completion(prompt, image_url, success, re_content): {"role": "user", "content": [ {"type": "text", "text": prompt}, {"type": "image_url", "image_url": { - "url": image_url, + "url": get_img_url(image_url), }}, ]}, ], @@ -90,19 +101,22 @@ def test_vision_chat_completion(prompt, image_url, success, re_content): "prompt, image_data, success, re_content", [ # test model is trained on CIFAR-10, but it's quite dumb due to small size - ("What is this: <__media__>\n", IMG_BASE64_0, True, "(cat)+"), - ("What is this: <__media__>\n", IMG_BASE64_1, True, "(frog)+"), + ("What is this: <__media__>\n", "IMG_BASE64_0", True, "(cat)+"), + ("What is this: <__media__>\n", "IMG_BASE64_1", True, "(frog)+"), ("What is this: <__media__>\n", "malformed", False, None), # non-image data ("What is this:\n", "", False, None), # empty string ] ) def test_vision_completion(prompt, image_data, success, re_content): global server - server.start() # vision model may take longer to load due to download size + server.start() res = server.make_request("POST", "/completions", data={ "temperature": 0.0, "top_k": 1, - "prompt": { JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] }, + "prompt": { + JSON_PROMPT_STRING_KEY: prompt, + JSON_MULTIMODAL_KEY: [ get_img_url(image_data) ], + }, }) if success: assert res.status_code == 200 @@ -116,17 +130,18 @@ def test_vision_completion(prompt, image_data, success, re_content): "prompt, image_data, success", [ # test model is trained on CIFAR-10, but it's quite dumb due to small size - ("What is this: <__media__>\n", IMG_BASE64_0, True), # exceptional, so that we don't cog up the log - ("What is this: <__media__>\n", IMG_BASE64_1, True), + ("What is this: <__media__>\n", "IMG_BASE64_0", True), + ("What is this: <__media__>\n", "IMG_BASE64_1", True), ("What is this: <__media__>\n", "malformed", False), # non-image data ("What is this:\n", "base64", False), # non-image data ] ) def test_vision_embeddings(prompt, image_data, success): global server - server.server_embeddings=True - server.n_batch=512 - server.start() # vision model may take longer to load due to download size + server.server_embeddings = True + server.n_batch = 512 + server.start() + image_data = get_img_url(image_data) res = server.make_request("POST", "/embeddings", data={ "content": [ { JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] }, diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index cda7434d7c..10997ef57c 100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -26,7 +26,7 @@ from re import RegexFlag import wget -DEFAULT_HTTP_TIMEOUT = 30 +DEFAULT_HTTP_TIMEOUT = 60 class ServerResponse: @@ -45,6 +45,7 @@ class ServerProcess: model_alias: str = "tinyllama-2" temperature: float = 0.8 seed: int = 42 + offline: bool = False # custom options model_alias: str | None = None @@ -118,6 +119,8 @@ class ServerProcess: "--seed", self.seed, ] + if self.offline: + server_args.append("--offline") if self.model_file: server_args.extend(["--model", self.model_file]) if self.model_url: @@ -392,6 +395,19 @@ server_instances: Set[ServerProcess] = set() class ServerPreset: + @staticmethod + def load_all() -> None: + """ Load all server presets to ensure model files are cached. """ + servers: List[ServerProcess] = [ + method() + for name, method in ServerPreset.__dict__.items() + if callable(method) and name != "load_all" + ] + for server in servers: + server.offline = False + server.start() + server.stop() + @staticmethod def tinyllama2() -> ServerProcess: server = ServerProcess() @@ -408,6 +424,7 @@ class ServerPreset: @staticmethod def bert_bge_small() -> ServerProcess: server = ServerProcess() + server.offline = True # will be downloaded by load_all() server.model_hf_repo = "ggml-org/models" server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf" server.model_alias = "bert-bge-small" @@ -422,6 +439,7 @@ class ServerPreset: @staticmethod def bert_bge_small_with_fa() -> ServerProcess: server = ServerProcess() + server.offline = True # will be downloaded by load_all() server.model_hf_repo = "ggml-org/models" server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf" server.model_alias = "bert-bge-small" @@ -437,6 +455,7 @@ class ServerPreset: @staticmethod def tinyllama_infill() -> ServerProcess: server = ServerProcess() + server.offline = True # will be downloaded by load_all() server.model_hf_repo = "ggml-org/models" server.model_hf_file = "tinyllamas/stories260K-infill.gguf" server.model_alias = "tinyllama-infill" @@ -451,6 +470,7 @@ class ServerPreset: @staticmethod def stories15m_moe() -> ServerProcess: server = ServerProcess() + server.offline = True # will be downloaded by load_all() server.model_hf_repo = "ggml-org/stories15M_MOE" server.model_hf_file = "stories15M_MOE-F16.gguf" server.model_alias = "stories15m-moe" @@ -465,6 +485,7 @@ class ServerPreset: @staticmethod def jina_reranker_tiny() -> ServerProcess: server = ServerProcess() + server.offline = True # will be downloaded by load_all() server.model_hf_repo = "ggml-org/models" server.model_hf_file = "jina-reranker-v1-tiny-en/ggml-model-f16.gguf" server.model_alias = "jina-reranker" @@ -478,6 +499,7 @@ class ServerPreset: @staticmethod def tinygemma3() -> ServerProcess: server = ServerProcess() + server.offline = True # will be downloaded by load_all() # mmproj is already provided by HF registry API server.model_hf_repo = "ggml-org/tinygemma3-GGUF" server.model_hf_file = "tinygemma3-Q8_0.gguf" From c4df49a42d396bdf7344501813e7de53bc9e7bb3 Mon Sep 17 00:00:00 2001 From: Charles Xu Date: Sat, 6 Sep 2025 16:08:43 +0200 Subject: [PATCH 02/46] kleidiai: generalize compute_forward_kv_cache to compute_forward_fp16 (#15817) --- ggml/src/ggml-cpu/kleidiai/kleidiai.cpp | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp index 7a830448eb..95f873dc77 100644 --- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp @@ -154,7 +154,7 @@ class tensor_traits : public ggml::cpu::tensor_traits { if (dst->src[0]->type == GGML_TYPE_Q4_0) { return compute_forward_q4_0(params, dst); } else if (dst->src[0]->type == GGML_TYPE_F16) { - return compute_forward_kv_cache(params, dst); + return compute_forward_fp16(params, dst); } } else if (dst->op == GGML_OP_GET_ROWS) { if (dst->src[0]->type == GGML_TYPE_Q4_0) { @@ -164,7 +164,7 @@ class tensor_traits : public ggml::cpu::tensor_traits { return false; } - bool compute_forward_kv_cache(ggml_compute_params * params, struct ggml_tensor * dst) { + bool compute_forward_fp16(ggml_compute_params * params, struct ggml_tensor * dst) { static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT; const ggml_tensor * src0 = dst->src[0]; @@ -534,13 +534,8 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) { return (ggml::cpu::tensor_traits *) op->src[0]->extra; } - else if (ggml_kleidiai_select_kernels(ctx.features, op) && - op->src[0]->op == GGML_OP_VIEW && - (op->src[1]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_SOFT_MAX) && - op->src[1]->ne[1] > 1) { - if ((op->src[0]->nb[0] != 2) || - (op->src[1]->nb[0] != 4) || - (op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) || + else if (ggml_kleidiai_select_kernels(ctx.features, op) && op->src[1]->ne[1] > 1) { + if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) || (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) { return nullptr; } From 79bc429262268ad2ac8a364cfe6c2d6b9c5f008a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 7 Sep 2025 00:26:28 +0200 Subject: [PATCH 03/46] CUDA: faster tile FA (Pascal/AMD), headsize 256 (#15769) --- ggml/src/ggml-cuda/fattn-tile-f16.cu | 371 ---------------- ggml/src/ggml-cuda/fattn-tile-f16.cuh | 3 - ggml/src/ggml-cuda/fattn-tile-f32.cu | 379 ---------------- ggml/src/ggml-cuda/fattn-tile-f32.cuh | 3 - ggml/src/ggml-cuda/fattn-tile.cu | 596 ++++++++++++++++++++++++++ ggml/src/ggml-cuda/fattn-tile.cuh | 3 + ggml/src/ggml-cuda/fattn.cu | 18 +- 7 files changed, 604 insertions(+), 769 deletions(-) delete mode 100644 ggml/src/ggml-cuda/fattn-tile-f16.cu delete mode 100644 ggml/src/ggml-cuda/fattn-tile-f16.cuh delete mode 100644 ggml/src/ggml-cuda/fattn-tile-f32.cu delete mode 100644 ggml/src/ggml-cuda/fattn-tile-f32.cuh create mode 100644 ggml/src/ggml-cuda/fattn-tile.cu create mode 100644 ggml/src/ggml-cuda/fattn-tile.cuh diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu deleted file mode 100644 index a900799a99..0000000000 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ /dev/null @@ -1,371 +0,0 @@ -#include "common.cuh" -#include "fattn-common.cuh" -#include "fattn-tile-f16.cuh" - -#define FATTN_KQ_STRIDE_TILE_F16 64 - -template // D == head size -#if !defined(GGML_USE_HIP) -__launch_bounds__(nwarps*WARP_SIZE, 2) -#endif // !defined(GGML_USE_HIP) -static __global__ void flash_attn_tile_ext_f16( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, - const float scale, - const float max_bias, - const float m0, - const float m1, - const uint32_t n_head_log2, - const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, - const int32_t nb01, const int32_t nb02, const int32_t nb03, - const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, - const int32_t nb11, const int32_t nb12, const int64_t nb13, - const int32_t nb21, const int32_t nb22, const int64_t nb23, - const int32_t ne31, const int32_t ne32, const int32_t ne33, - const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) - - // Skip unused kernel variants for faster compilation: -#ifdef FP16_MMA_AVAILABLE - NO_DEVICE_CODE; - return; -#endif // FP16_MMA_AVAILABLE - if (use_logit_softcap && !(D == 128 || D == 256)) { - NO_DEVICE_CODE; - return; - } - - //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - - const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. - - const int sequence = blockIdx.z / ne02; - const int head = blockIdx.z - sequence*ne02; - const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0); - const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); - const float * sinksf = (const float *) (sinks); - - const int stride_KV2 = nb11 / sizeof(half2); - - const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); - const half slopeh = __float2half(slopef); - - static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); - - __shared__ half KQ[ncols*FATTN_KQ_STRIDE_TILE_F16]; - half2 * KQ2 = (half2 *) KQ; - - __shared__ half2 KV_tmp[FATTN_KQ_STRIDE_TILE_F16][D/2 + 1]; // Pad D to avoid memory bank conflicts. - - half kqmax[ncols/nwarps]; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - kqmax[j0/nwarps] = -HALF_MAX_HALF; - } - half2 kqsum[ncols/nwarps] = {{0.0f, 0.0f}}; - - half2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}}; - - // Convert Q to half2 and store in registers: - __shared__ half2 Q_h2[ncols][D/2]; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f); - Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y); - } - } - - __syncthreads(); - - const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; - for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) { - // Calculate KQ tile and keep track of new maximum KQ values: - - half kqmax_new[ncols/nwarps]; -#pragma unroll - for (int j = 0; j < ncols/nwarps; ++j) { - kqmax_new[j] = kqmax[j]; - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += nwarps) { - const int i_KQ = i_KQ_0 + threadIdx.y; - -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { - const int k_KQ = k_KQ_0 + threadIdx.x; - - KV_tmp[i_KQ][k_KQ] = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ]; - } - } - - __syncthreads(); - - half2 sum2[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE][ncols/nwarps] = {{{0.0f, 0.0f}}}; - -#pragma unroll - for (int k_KQ = 0; k_KQ < D/2; ++k_KQ) { - half2 K_k[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE]; - half2 Q_k[ncols/nwarps]; - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) { - const int i_KQ = i_KQ_0 + threadIdx.x; - - K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ]; - } -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - const int j_KQ = j_KQ_0 + threadIdx.y; - - Q_k[j_KQ_0/nwarps] = Q_h2[j_KQ][k_KQ]; - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) { -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE]*Q_k[j_KQ_0/nwarps]; - } - } - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) { - const int i_KQ = i_KQ_0 + threadIdx.x; - -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - const int j_KQ = j_KQ_0 + threadIdx.y; - - half sum; - if (use_logit_softcap) { - const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); - sum = logit_softcap * tanhf(tmp.x + tmp.y); - } else { - sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); - } - sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); - - kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum); - - KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F16 + i_KQ] = sum; - } - } - - __syncthreads(); - -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]); - const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new[j0/nwarps])); - kqmax[j0/nwarps] = kqmax_new[j0/nwarps]; - -#pragma unroll - for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F16/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const half2 diff = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] - __half2half2(kqmax[j0/nwarps]); - const half2 val = h2exp(diff); - kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + val; - KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] = val; - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale; - } - } - - __syncthreads(); - -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += nwarps) { - const int k = k0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - KV_tmp[k][i] = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i]; - } - } - - __syncthreads(); - -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += 2) { - half2 V_k[(D/2)/WARP_SIZE][2]; - half2 KQ_k[ncols/nwarps]; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - V_k[i0/WARP_SIZE][0] = KV_tmp[k0 + 0][i]; - V_k[i0/WARP_SIZE][1] = KV_tmp[k0 + 1][i]; - } -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - KQ_k[j0/nwarps] = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + k0/2]; - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][0]* __low2half2(KQ_k[j0/nwarps]); - VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][1]*__high2half2(KQ_k[j0/nwarps]); - } - } - } - - __syncthreads(); - } - - //Attention sink: adjust running max and sum once per head - if (sinksf && blockIdx.y == 0) { - const half sink = __float2half(sinksf[head]); - -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - half kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink); - kqmax_new_j = warp_reduce_max(kqmax_new_j); - - const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new_j)); - kqmax[j0/nwarps] = kqmax_new_j; - - const half val = hexp(sink - kqmax[j0/nwarps]); - kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale; - if (threadIdx.x == 0) { - kqsum[j0/nwarps].x = __hadd(__low2half(kqsum[j0/nwarps]), val); - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale; - } - } - } - - float2 * dst2 = (float2 *) dst; - -#pragma unroll - for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) { - const int j_VKQ = j_VKQ_0 + threadIdx.y; - - if (ic0 + j_VKQ >= ne01) { - return; - } - - half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]); - kqsum_j = warp_reduce_sum((float)kqsum_j); - - const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; - -#pragma unroll - for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) { - const int i0 = i00 + threadIdx.x; - - half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE]; - if (gridDim.y == 1) { - dst_val /= __half2half2(kqsum_j); - } - dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val); - } - - if (gridDim.y != 1 && threadIdx.x == 0) { - dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); - } - } -#else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, - max_bias, m0, m1, n_head_log2, logit_softcap, - ne00, ne01, ne02, ne03, - nb01, nb02, nb03, - ne10, ne11, ne12, ne13, - nb11, nb12, nb13, - nb21, nb22, nb23, - ne31, ne32, ne33, - nb31, nb32, nb33); - NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) -} - -template -void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; - switch (Q->ne[0]) { - case 64: { - constexpr int D = 64; - constexpr int nwarps = 8; - constexpr size_t nbytes_shared = 0; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false); - } break; - case 128: { - constexpr int D = 128; - constexpr int nwarps = 8; - constexpr size_t nbytes_shared = 0; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false); - } break; - default: { - GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); - } break; - } -} - -void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; - - const int32_t precision = KQV->op_params[3]; - GGML_ASSERT(precision == GGML_PREC_DEFAULT); - - float logit_softcap; - memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - - if (Q->ne[1] <= 16) { - constexpr int cols_per_block = 16; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - launch_fattn_tile_f16_64_128(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - launch_fattn_tile_f16_64_128(ctx, dst); - } - return; - } - - constexpr int cols_per_block = 32; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - launch_fattn_tile_f16_64_128(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - launch_fattn_tile_f16_64_128(ctx, dst); - } -} diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cuh b/ggml/src/ggml-cuda/fattn-tile-f16.cuh deleted file mode 100644 index ffc5878427..0000000000 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cuh +++ /dev/null @@ -1,3 +0,0 @@ -#include "common.cuh" - -void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu deleted file mode 100644 index b96a9ef971..0000000000 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ /dev/null @@ -1,379 +0,0 @@ -#include "common.cuh" -#include "fattn-common.cuh" -#include "fattn-tile-f32.cuh" - -#define FATTN_KQ_STRIDE_TILE_F32 32 - -template // D == head size -#if !defined(GGML_USE_HIP) -__launch_bounds__(nwarps*WARP_SIZE, 2) -#endif // !defined(GGML_USE_HIP) -static __global__ void flash_attn_tile_ext_f32( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, - const float scale, - const float max_bias, - const float m0, - const float m1, - const uint32_t n_head_log2, - const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, - const int32_t nb01, const int32_t nb02, const int32_t nb03, - const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, - const int32_t nb11, const int32_t nb12, const int64_t nb13, - const int32_t nb21, const int32_t nb22, const int64_t nb23, - const int32_t ne31, const int32_t ne32, const int32_t ne33, - const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#ifdef FLASH_ATTN_AVAILABLE - - // Skip unused kernel variants for faster compilation: -#ifdef FP16_MMA_AVAILABLE - NO_DEVICE_CODE; - return; -#endif // FP16_MMA_AVAILABLE - if (use_logit_softcap && !(D == 128 || D == 256)) { - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, - max_bias, m0, m1, n_head_log2, logit_softcap, - ne00, ne01, ne02, ne03, - nb01, nb02, nb03, - ne10, ne11, ne12, ne13, - nb11, nb12, nb13, - nb21, nb22, nb23, - ne31, ne32, ne33, - nb31, nb32, nb33); - NO_DEVICE_CODE; - return; - } - - // In this kernel Q, K, V are matrices while i, j, k are matrix indices. - - const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. - - const int sequence = blockIdx.z / ne02; - const int head = blockIdx.z - sequence*ne02; - const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0); - const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); - const float * sinksf = (const float *) (sinks); - - const int stride_KV2 = nb11 / sizeof(half2); - - const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); - - static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); - - __shared__ float KQ[ncols*FATTN_KQ_STRIDE_TILE_F32]; - - __shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][D + 1]; // Pad D to avoid memory bank conflicts. - float2 * KV_tmp2 = (float2 *) KV_tmp; - - float kqmax[ncols/nwarps]; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - kqmax[j0/nwarps] = -FLT_MAX/2.0f; - } - float kqsum[ncols/nwarps] = {0.0f}; - - float2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}}; - - // Convert Q to half2 and store in registers: - __shared__ float Q_f[ncols][D]; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) { - float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x] : make_float2(0.0f, 0.0f); - Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale; - Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale; - } - } - - __syncthreads(); - - const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; - for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) { - // Calculate KQ tile and keep track of new maximum KQ values: - - float kqmax_new[ncols/nwarps]; -#pragma unroll - for (int j = 0; j < ncols/nwarps; ++j) { - kqmax_new[j] = kqmax[j]; - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += nwarps) { - const int i_KQ = i_KQ_0 + threadIdx.y; - -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) { - const half2 tmp = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x]; - KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp); - KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp); - } - } - - __syncthreads(); - - float sum[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE][ncols/nwarps] = {{0.0f}}; - -#pragma unroll - for (int k_KQ = 0; k_KQ < D; ++k_KQ) { - float K_k[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE]; - float Q_k[ncols/nwarps]; - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) { - const int i_KQ = i_KQ_0 + threadIdx.x; - - K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ]; - } -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - const int j_KQ = j_KQ_0 + threadIdx.y; - - Q_k[j_KQ_0/nwarps] = Q_f[j_KQ][k_KQ]; - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) { -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE] * Q_k[j_KQ_0/nwarps]; - } - } - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) { - const int i_KQ = i_KQ_0 + threadIdx.x; - -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - const int j_KQ = j_KQ_0 + threadIdx.y; - - if (use_logit_softcap) { - sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); - } - - sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; - - kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); - - KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F32 + i_KQ] = sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]; - } - } - - __syncthreads(); - -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]); - const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]); - kqmax[j0/nwarps] = kqmax_new[j0/nwarps]; - - float kqsum_add = 0.0f; -#pragma unroll - for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F32; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const float diff = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] - kqmax[j0/nwarps]; - const float val = expf(diff); - kqsum_add += val; - KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] = val; - } - kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale; - VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale; - } - } - - __syncthreads(); - -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F32; k0 += nwarps) { - const int k = k0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const half2 tmp = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i]; - KV_tmp2[k*(D/2) + i].x = __low2float(tmp); - KV_tmp2[k*(D/2) + i].y = __high2float(tmp); - } - } - - __syncthreads(); - -#pragma unroll - for (int k = 0; k < FATTN_KQ_STRIDE_TILE_F32; ++k) { - float2 V_k[(D/2)/WARP_SIZE]; - float KQ_k[ncols/nwarps]; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - V_k[i0/WARP_SIZE] = KV_tmp2[k*(D/2) + i]; - } -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - KQ_k[j0/nwarps] = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + k]; - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - VKQ[j0/nwarps][i0/WARP_SIZE].x += V_k[i0/WARP_SIZE].x*KQ_k[j0/nwarps]; - VKQ[j0/nwarps][i0/WARP_SIZE].y += V_k[i0/WARP_SIZE].y*KQ_k[j0/nwarps]; - } - } - } - - __syncthreads(); - } - - - //Attention sink: adjust running max and sum once per head - if (sinksf && blockIdx.y == 0) { - const float sink = sinksf[head]; - -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink); - kqmax_new_j = warp_reduce_max(kqmax_new_j); - - const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j); - kqmax[j0/nwarps] = kqmax_new_j; - - const float val = expf(sink - kqmax[j0/nwarps]); - kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale; - if (threadIdx.x == 0) { - kqsum[j0/nwarps] += val; - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale; - VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale; - } - } - } - - float2 * dst2 = (float2 *) dst; - -#pragma unroll - for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) { - const int j_VKQ = j_VKQ_0 + threadIdx.y; - - if (ic0 + j_VKQ >= ne01) { - return; - } - - float kqsum_j = kqsum[j_VKQ_0/nwarps]; - kqsum_j = warp_reduce_sum(kqsum_j); - - const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; - -#pragma unroll - for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) { - const int i0 = i00 + threadIdx.x; - - float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE]; - if (gridDim.y == 1) { - dst_val.x /= kqsum_j; - dst_val.y /= kqsum_j; - } - dst2[j_dst_unrolled*(D/2) + i0] = dst_val; - } - - if (gridDim.y != 1 && threadIdx.x == 0) { - dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); - } - } -#else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, - max_bias, m0, m1, n_head_log2, logit_softcap, - ne00, ne01, ne02, ne03, - nb01, nb02, nb03, - ne10, ne11, ne12, ne13, - nb11, nb12, nb13, - nb21, nb22, nb23, - ne31, ne32, ne33, - nb31, nb32, nb33); - NO_DEVICE_CODE; -#endif // FLASH_ATTN_AVAILABLE -} - -template -void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; - switch (Q->ne[0]) { - case 64: { - constexpr int D = 64; - constexpr int nwarps = 8; - constexpr size_t nbytes_shared = 0; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false); - } break; - case 128: { - constexpr int D = 128; - constexpr int nwarps = 8; - constexpr size_t nbytes_shared = 0; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false); - } break; - default: { - GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); - } break; - } -} - -void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; - - float logit_softcap; - memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - - if (Q->ne[1] <= 16) { - constexpr int cols_per_block = 16; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - launch_fattn_tile_f32_64_128(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - launch_fattn_tile_f32_64_128(ctx, dst); - } - return; - } - - constexpr int cols_per_block = 32; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - launch_fattn_tile_f32_64_128(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - launch_fattn_tile_f32_64_128(ctx, dst); - } -} diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cuh b/ggml/src/ggml-cuda/fattn-tile-f32.cuh deleted file mode 100644 index b1c546c805..0000000000 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cuh +++ /dev/null @@ -1,3 +0,0 @@ -#include "common.cuh" - -void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu new file mode 100644 index 0000000000..fb2163acd0 --- /dev/null +++ b/ggml/src/ggml-cuda/fattn-tile.cu @@ -0,0 +1,596 @@ +#include "common.cuh" +#include "fattn-common.cuh" +#include "fattn-tile.cuh" + +#define FATTN_TILE_NTHREADS 256 + +static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int cc, const int warp_size) { + if (GGML_CUDA_CC_IS_AMD(cc)) { + switch (D) { + case 64: + return ncols <= 16 ? 32 : 64; + case 128: + return ncols <= 16 ? 64 : warp_size; + case 256: + return 64; + default: + GGML_ABORT("fatal error"); + return -1; + } + } + if (fast_fp16_available(cc)) { + switch (D) { + case 64: + case 128: + return 128; + case 256: + return ncols <= 16 ? 128 : 64; + default: + GGML_ABORT("fatal error"); + return -1; + } + } + switch (D) { + case 64: + return ncols <= 16 ? 128 : 64; + case 128: + return ncols <= 16 ? 64 : 32; + case 256: + return 32; + default: + GGML_ABORT("fatal error"); + return -1; + } +} + +static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) { +#ifdef GGML_USE_HIP + switch (D) { + case 64: + return ncols <= 16 ? 32 : 64; + case 128: + return ncols <= 16 ? 64 : warp_size; + case 256: + return 64; + default: + return -1; + } +#else +#ifdef FAST_FP16_AVAILABLE + switch (D) { + case 64: + case 128: + return 128; + case 256: + return ncols <= 16 ? 128 : 64; + default: + return -1; + } +#else + switch (D) { + case 64: + return ncols <= 16 ? 128 : 64; + case 128: + return ncols <= 16 ? 64 : 32; + case 256: + return 32; + default: + return -1; + } +#endif // FAST_FP16_AVAILABLE +#endif // GGML_USE_HIP + GGML_UNUSED_VARS(ncols, warp_size); +} + +static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols, int warp_size) { +#ifdef GGML_USE_HIP + switch (D) { + case 64: + return 64; + case 128: + return ncols <= 16 ? 2*warp_size : 128; + case 256: + return ncols <= 16 ? 128 : 2*warp_size; + default: + return -1; + } +#else +#ifdef FAST_FP16_AVAILABLE + switch (D) { + case 64: + return 64; + case 128: + return ncols <= 16 ? 128 : 64; + case 256: + return ncols <= 16 ? 64 : 128; + default: + return -1; + } +#else + switch (D) { + case 64: + return 64; + case 128: + return 128; + case 256: + return ncols <= 16 ? 128 : 64; + default: + return -1; + } +#endif // FAST_FP16_AVAILABLE +#endif // GGML_USE_HIP + GGML_UNUSED_VARS(ncols, warp_size); +} + +template // D == head size +#ifdef GGML_USE_HIP +__launch_bounds__(FATTN_TILE_NTHREADS, 1) +#else +__launch_bounds__(FATTN_TILE_NTHREADS, 2) +#endif // GGML_USE_HIP +static __global__ void flash_attn_tile( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + const char * __restrict__ sinks, + const int * __restrict__ KV_max, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33) { +#ifdef FLASH_ATTN_AVAILABLE + + // Skip unused kernel variants for faster compilation: +#ifdef FP16_MMA_AVAILABLE + NO_DEVICE_CODE; + return; +#endif // FP16_MMA_AVAILABLE + + if (use_logit_softcap && !(D == 128 || D == 256)) { + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + NO_DEVICE_CODE; + return; + } + + constexpr int warp_size = 32; + constexpr int nwarps = FATTN_TILE_NTHREADS / warp_size; + constexpr int kq_stride = fattn_tile_get_kq_stride_device(D, ncols, warp_size); + static_assert(kq_stride % warp_size == 0, "kq_stride not divisable by warp_size."); + constexpr int kq_nbatch = fattn_tile_get_kq_nbatch_device(D, ncols, warp_size); + static_assert(kq_nbatch % (2*warp_size) == 0, "bad kq_nbatch"); + + // In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. + + const int sequence = blockIdx.z / ne02; + const int head = blockIdx.z - sequence*ne02; + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0); + const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); + const float * sinksf = (const float *) (sinks); + + const int stride_KV2 = nb11 / sizeof(half2); + + const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); + + __shared__ float KQ[ncols][kq_stride]; +#ifdef FAST_FP16_AVAILABLE + __shared__ half2 Q_tmp[ncols][D/2]; + __shared__ half2 KV_tmp_h2[kq_stride * (kq_nbatch/2 + 1)]; // Padded to avoid memory bank conflicts. + half2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}}; +#else + __shared__ float Q_tmp[ncols][D]; + __shared__ float KV_tmp_f[kq_stride * (kq_nbatch + 1)]; // Padded to avoid memory bank conflicts. + float2 * KV_tmp_f2 = (float2 *) KV_tmp_f; + float2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}}; +#endif // FAST_FP16_AVAILABLE + + + float kqmax[ncols/nwarps]; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + kqmax[j0/nwarps] = -FLT_MAX/2.0f; + } + float kqsum[ncols/nwarps] = {0.0f}; + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0 + threadIdx.x] : make_float2(0.0f, 0.0f); +#ifdef FAST_FP16_AVAILABLE + Q_tmp[j][i0 + threadIdx.x] = make_half2(tmp.x * scale, tmp.y * scale); +#else + Q_tmp[j][2*i0 + threadIdx.x] = tmp.x * scale; + Q_tmp[j][2*i0 + warp_size + threadIdx.x] = tmp.y * scale; +#endif // FAST_FP16_AVAILABLE + } + } + + __syncthreads(); + + const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; + for (int k_VKQ_0 = blockIdx.y*kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*kq_stride) { + // Calculate KQ tile and keep track of new maximum KQ values: + + float kqmax_new[ncols/nwarps]; +#pragma unroll + for (int j = 0; j < ncols/nwarps; ++j) { + kqmax_new[j] = kqmax[j]; + } + + float sum[kq_stride/warp_size][ncols/nwarps] = {{0.0f}}; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += kq_nbatch) { +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += nwarps) { + const int i_KQ = i_KQ_0 + threadIdx.y; + +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size) { + const half2 tmp_h2 = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x]; +#ifdef FAST_FP16_AVAILABLE + KV_tmp_h2[i_KQ*(kq_nbatch/2 + 1) + k_KQ_1 + threadIdx.x] = tmp_h2; +#else + const float2 tmp_f2 = __half22float2(tmp_h2); + KV_tmp_f[i_KQ*(kq_nbatch + 1) + 2*k_KQ_1 + threadIdx.x] = tmp_f2.x; + KV_tmp_f[i_KQ*(kq_nbatch + 1) + 2*k_KQ_1 + warp_size + threadIdx.x] = tmp_f2.y; +#endif // FAST_FP16_AVAILABLE + } + } + + __syncthreads(); + +#ifdef FAST_FP16_AVAILABLE +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; ++k_KQ_1) { + half2 K_k[kq_stride/warp_size]; + half2 Q_k[ncols/nwarps]; +#else +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; ++k_KQ_1) { + float K_k[kq_stride/warp_size]; + float Q_k[ncols/nwarps]; +#endif // FAST_FP16_AVAILABLE + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { + const int i_KQ = i_KQ_0 + threadIdx.x; + +#ifdef FAST_FP16_AVAILABLE + K_k[i_KQ_0/warp_size] = KV_tmp_h2[i_KQ*(kq_nbatch/2 + 1) + k_KQ_1]; +#else + K_k[i_KQ_0/warp_size] = KV_tmp_f [i_KQ*(kq_nbatch + 1) + k_KQ_1]; +#endif // FAST_FP16_AVAILABLE + } +#pragma unroll + for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { + const int j_KQ = j_KQ_0 + threadIdx.y; + +#ifdef FAST_FP16_AVAILABLE + Q_k[j_KQ_0/nwarps] = Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]; +#else + Q_k[j_KQ_0/nwarps] = Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]; +#endif // FAST_FP16_AVAILABLE + } + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { +#pragma unroll + for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { +#ifdef FAST_FP16_AVAILABLE + const float2 tmp = __half22float2(K_k[i_KQ_0/warp_size] * Q_k[j_KQ_0/nwarps]); + sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += tmp.x + tmp.y; +#else + sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += K_k[i_KQ_0/warp_size] * Q_k[j_KQ_0/nwarps]; +#endif // FAST_FP16_AVAILABLE + } + } + } + + if (k_KQ_0 + kq_nbatch < D) { + __syncthreads(); // Sync not needed on last iteration. + } + } + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { + const int i_KQ = i_KQ_0 + threadIdx.x; + +#pragma unroll + for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { + const int j_KQ = j_KQ_0 + threadIdx.y; + + if (use_logit_softcap) { + sum[i_KQ_0/warp_size][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/warp_size][j_KQ_0/nwarps]); + } + + sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; + + kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/warp_size][j_KQ_0/nwarps]); + + KQ[j_KQ][i_KQ] = sum[i_KQ_0/warp_size][j_KQ_0/nwarps]; + } + } + + __syncthreads(); + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]); + const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]); + kqmax[j0/nwarps] = kqmax_new[j0/nwarps]; + + float kqsum_add = 0.0f; +#pragma unroll + for (int i0 = 0; i0 < kq_stride; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const float diff = KQ[j][i] - kqmax[j0/nwarps]; + const float val = expf(diff); + kqsum_add += val; + KQ[j][i] = val; + } + kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add; + +#ifdef FAST_FP16_AVAILABLE + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + VKQ[j0/nwarps][i0/warp_size] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + VKQ[j0/nwarps][i0/warp_size].x *= KQ_max_scale; + VKQ[j0/nwarps][i0/warp_size].y *= KQ_max_scale; + } +#endif // FAST_FP16_AVAILABLE + } + + constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D; + static_assert(kq_stride % V_cols_per_iter == 0, "bad V_cols_per_iter"); +#pragma unroll + for (int k0 = 0; k0 < kq_stride; k0 += V_cols_per_iter) { +#pragma unroll + for (int k1 = 0; k1 < V_cols_per_iter; k1 += nwarps) { + const int k_tile = k1 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const half2 tmp = V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i]; +#ifdef FAST_FP16_AVAILABLE + KV_tmp_h2[k_tile*(D/2) + i] = tmp; +#else + KV_tmp_f2[k_tile*(D/2) + i] = __half22float2(tmp); +#endif // FAST_FP16_AVAILABLE + } + } + + __syncthreads(); + +#pragma unroll + for (int k1 = 0; k1 < V_cols_per_iter; ++k1) { +#ifdef FAST_FP16_AVAILABLE + half2 V_k[(D/2)/warp_size]; + half2 KQ_k[ncols/nwarps]; +#else + float2 V_k[(D/2)/warp_size]; + float KQ_k[ncols/nwarps]; +#endif // FAST_FP16_AVAILABLE + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + const int i = i0 + threadIdx.x; + +#ifdef FAST_FP16_AVAILABLE + V_k[i0/warp_size] = KV_tmp_h2[k1*(D/2) + i]; +#else + V_k[i0/warp_size] = KV_tmp_f2[k1*(D/2) + i]; +#endif // FAST_FP16_AVAILABLE + } +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#ifdef FAST_FP16_AVAILABLE + const float tmp = KQ[j][k0 + k1]; + KQ_k[j0/nwarps] = make_half2(tmp, tmp); +#else + KQ_k[j0/nwarps] = KQ[j][k0 + k1]; +#endif // FAST_FP16_AVAILABLE + } + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { +#ifdef FAST_FP16_AVAILABLE + VKQ[j0/nwarps][i0/warp_size] += V_k[i0/warp_size] *KQ_k[j0/nwarps]; +#else + VKQ[j0/nwarps][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0/nwarps]; + VKQ[j0/nwarps][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0/nwarps]; +#endif // FAST_FP16_AVAILABLE + } + } + } + + __syncthreads(); + } + } + + + // Attention sink: adjust running max and sum once per head + if (sinksf && blockIdx.y == 0) { + const float sink = sinksf[head]; + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink); + kqmax_new_j = warp_reduce_max(kqmax_new_j); + + const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j); + kqmax[j0/nwarps] = kqmax_new_j; + + const float val = expf(sink - kqmax[j0/nwarps]); + kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale; + if (threadIdx.x == 0) { + kqsum[j0/nwarps] += val; + } + +#ifdef FAST_FP16_AVAILABLE + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + VKQ[j0/nwarps][i0/warp_size] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + VKQ[j0/nwarps][i0/warp_size].x *= KQ_max_scale; + VKQ[j0/nwarps][i0/warp_size].y *= KQ_max_scale; + } +#endif // FAST_FP16_AVAILABLE + } + } + + float2 * dst2 = (float2 *) dst; + +#pragma unroll + for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) { + const int j_VKQ = j_VKQ_0 + threadIdx.y; + + if (ic0 + j_VKQ >= ne01) { + return; + } + + float kqsum_j = kqsum[j_VKQ_0/nwarps]; + kqsum_j = warp_reduce_sum(kqsum_j); + + const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; + +#pragma unroll + for (int i00 = 0; i00 < D/2; i00 += warp_size) { + const int i0 = i00 + threadIdx.x; + +#ifdef FAST_FP16_AVAILABLE + float2 dst_val = __half22float2(VKQ[j_VKQ_0/nwarps][i0/warp_size]); +#else + float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/warp_size]; +#endif // FAST_FP16_AVAILABLE + + if (gridDim.y == 1) { + dst_val.x /= kqsum_j; + dst_val.y /= kqsum_j; + } + dst2[j_dst_unrolled*(D/2) + i0] = dst_val; + } + + if (gridDim.y != 1 && threadIdx.x == 0) { + dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); + } + } +#else + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + NO_DEVICE_CODE; +#endif // FLASH_ATTN_AVAILABLE +} + +template +static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const int warp_size = 32; + const int nwarps = FATTN_TILE_NTHREADS / warp_size; + + constexpr size_t nbytes_shared = 0; + + if (Q->ne[1] > 16) { + constexpr int cols_per_block = 32; + fattn_kernel_t fattn_kernel = flash_attn_tile; + const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size); + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size); + return; + } + + constexpr int cols_per_block = 16; + fattn_kernel_t fattn_kernel = flash_attn_tile; + const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size); + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size); +} + +template +static void launch_fattn_tile_switch_head_size(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + switch (Q->ne[0]) { + case 64: { + launch_fattn_tile_switch_ncols< 64, use_logit_softcap>(ctx, dst); + } break; + case 128: { + launch_fattn_tile_switch_ncols<128, use_logit_softcap>(ctx, dst); + } break; + case 256: { + launch_fattn_tile_switch_ncols<256, use_logit_softcap>(ctx, dst); + } break; + default: { + GGML_ABORT("Unsupported head size"); + } break; + } +} + +void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_switch_head_size(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_switch_head_size(ctx, dst); + } +} diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh new file mode 100644 index 0000000000..10dc22d1bf --- /dev/null +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 4883427266..7626d89ca0 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -1,8 +1,7 @@ #include "common.cuh" #include "fattn-common.cuh" #include "fattn-mma-f16.cuh" -#include "fattn-tile-f16.cuh" -#include "fattn-tile-f32.cuh" +#include "fattn-tile.cuh" #include "fattn-vec-f16.cuh" #include "fattn-vec-f32.cuh" #include "fattn-wmma-f16.cuh" @@ -271,8 +270,7 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg // Best FlashAttention kernel for a specific GPU: enum best_fattn_kernel { BEST_FATTN_KERNEL_NONE = 0, - BEST_FATTN_KERNEL_TILE_F32 = 200, - BEST_FATTN_KERNEL_TILE_F16 = 210, + BEST_FATTN_KERNEL_TILE = 200, BEST_FATTN_KERNEL_VEC_F32 = 100, BEST_FATTN_KERNEL_VEC_F16 = 110, BEST_FATTN_KERNEL_WMMA_F16 = 300, @@ -411,10 +409,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } // If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes: - if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) { - return BEST_FATTN_KERNEL_TILE_F16; - } - return BEST_FATTN_KERNEL_TILE_F32; + return BEST_FATTN_KERNEL_TILE; } void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -422,11 +417,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) { case BEST_FATTN_KERNEL_NONE: GGML_ABORT("fatal error"); - case BEST_FATTN_KERNEL_TILE_F32: - ggml_cuda_flash_attn_ext_tile_f32(ctx, dst); - break; - case BEST_FATTN_KERNEL_TILE_F16: - ggml_cuda_flash_attn_ext_tile_f16(ctx, dst); + case BEST_FATTN_KERNEL_TILE: + ggml_cuda_flash_attn_ext_tile(ctx, dst); break; case BEST_FATTN_KERNEL_VEC_F32: ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); From 3b15924d71237a43bb5ad71f5b885ee66a821342 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Sun, 7 Sep 2025 10:19:45 +0200 Subject: [PATCH 04/46] ggml WebGPU: remove userdata from request adapter callback (#15527) * ggml WebGPU: remove userdata from request adapter callback This commit removes the `userdata` parameter from the WebGPU request adapter callback in `ggml-webgpu.cpp`. Instead, the lambda function captures the `webgpu_context` directly. The motivation for this change is to simplify the code and improve readability. * inline the callback lambda into the RequestAdapter call This commit removes the callback lambda variable and inlines it directly into the RequestAdapter call. --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index e5df883c13..aad27bf605 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1154,17 +1154,15 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t webgpu_context ctx = reg_ctx->webgpu_ctx; wgpu::RequestAdapterOptions options = {}; - auto callback = - [](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message, void * userdata) { - if (status != wgpu::RequestAdapterStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); - return; - } - *static_cast(userdata) = std::move(adapter); - }; - void * userdata = &ctx->adapter; ctx->instance.WaitAny( - ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::AllowSpontaneous, callback, userdata), UINT64_MAX); + ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::AllowSpontaneous, + [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) { + if (status != wgpu::RequestAdapterStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); + return; + } + ctx->adapter = std::move(adapter); + }), UINT64_MAX); GGML_ASSERT(ctx->adapter != nullptr); ctx->adapter.GetLimits(&ctx->limits); From 267e99867f09bec8bcc2e424ad9bcddd6cccf9d0 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sun, 7 Sep 2025 11:53:07 -0500 Subject: [PATCH 05/46] vulkan: Use larger loads in scalar/coopmat1 matmul (#15729) I think glslang will translate an access like x[i][1].z to OpAccessChain ... x, i, 1, 2 OpLoad float16_t ... rather than loading all of x[i] in a single OpLoad. Change the code to explicitly load the vector/matrix. --- .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 60 +++++++++++-------- .../src/ggml-vulkan/vulkan-shaders/types.comp | 11 ++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 20 +++---- 3 files changed, 57 insertions(+), 34 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 7e10e99e9e..f6a7761ffa 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -315,21 +315,23 @@ void main() { #if LOAD_VEC_A == 8 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x); - buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y); - buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z); - buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx][0].w); - buf_a[buf_idx + 4] = FLOAT_TYPE(data_a[idx][1].x); - buf_a[buf_idx + 5] = FLOAT_TYPE(data_a[idx][1].y); - buf_a[buf_idx + 6] = FLOAT_TYPE(data_a[idx][1].z); - buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w); + A_TYPE32 aa = A_TYPE32(data_a[idx]); + buf_a[buf_idx ] = FLOAT_TYPE(aa[0].x); + buf_a[buf_idx + 1] = FLOAT_TYPE(aa[0].y); + buf_a[buf_idx + 2] = FLOAT_TYPE(aa[0].z); + buf_a[buf_idx + 3] = FLOAT_TYPE(aa[0].w); + buf_a[buf_idx + 4] = FLOAT_TYPE(aa[1].x); + buf_a[buf_idx + 5] = FLOAT_TYPE(aa[1].y); + buf_a[buf_idx + 6] = FLOAT_TYPE(aa[1].z); + buf_a[buf_idx + 7] = FLOAT_TYPE(aa[1].w); #elif LOAD_VEC_A == 4 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x); - buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y); - buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z); - buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w); + A_TYPE32 aa = A_TYPE32(data_a[idx]); + buf_a[buf_idx ] = FLOAT_TYPE(aa.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(aa.y); + buf_a[buf_idx + 2] = FLOAT_TYPE(aa.z); + buf_a[buf_idx + 3] = FLOAT_TYPE(aa.w); #else if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) { buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]); @@ -808,14 +810,19 @@ void main() { const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; #endif const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B; - buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x); - buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y); - buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z); - buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx][0].w); - buf_b[buf_idx + 4] = FLOAT_TYPE(data_b[idx][1].x); - buf_b[buf_idx + 5] = FLOAT_TYPE(data_b[idx][1].y); - buf_b[buf_idx + 6] = FLOAT_TYPE(data_b[idx][1].z); - buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w); +#if defined(DATA_B_BF16) + B_TYPE32 bb = TO_FLOAT_TYPE(data_b[idx]); +#else + B_TYPE32 bb = B_TYPE32(data_b[idx]); +#endif + buf_b[buf_idx + 0] = FLOAT_TYPE(bb[0].x); + buf_b[buf_idx + 1] = FLOAT_TYPE(bb[0].y); + buf_b[buf_idx + 2] = FLOAT_TYPE(bb[0].z); + buf_b[buf_idx + 3] = FLOAT_TYPE(bb[0].w); + buf_b[buf_idx + 4] = FLOAT_TYPE(bb[1].x); + buf_b[buf_idx + 5] = FLOAT_TYPE(bb[1].y); + buf_b[buf_idx + 6] = FLOAT_TYPE(bb[1].z); + buf_b[buf_idx + 7] = FLOAT_TYPE(bb[1].w); #elif LOAD_VEC_B == 4 #ifdef MUL_MAT_ID const u16vec2 row_idx = row_ids[loadc_b + l]; @@ -824,10 +831,15 @@ void main() { const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; #endif const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B; - buf_b[buf_idx + 0] = TO_FLOAT_TYPE(data_b[idx].x); - buf_b[buf_idx + 1] = TO_FLOAT_TYPE(data_b[idx].y); - buf_b[buf_idx + 2] = TO_FLOAT_TYPE(data_b[idx].z); - buf_b[buf_idx + 3] = TO_FLOAT_TYPE(data_b[idx].w); +#if defined(DATA_B_BF16) + B_TYPE32 bb = TO_FLOAT_TYPE(data_b[idx]); +#else + B_TYPE32 bb = B_TYPE32(data_b[idx]); +#endif + buf_b[buf_idx + 0] = FLOAT_TYPE(bb.x); + buf_b[buf_idx + 1] = FLOAT_TYPE(bb.y); + buf_b[buf_idx + 2] = FLOAT_TYPE(bb.z); + buf_b[buf_idx + 3] = FLOAT_TYPE(bb.w); #elif !MUL_MAT_ID if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) { buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp index 408722c878..c2acc803f6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp @@ -13,10 +13,13 @@ #if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 #define A_TYPE float +#define A_TYPE32 float #elif LOAD_VEC_A == 4 #define A_TYPE vec4 +#define A_TYPE32 vec4 #elif LOAD_VEC_A == 8 #define A_TYPE mat2x4 +#define A_TYPE32 mat2x4 #endif #endif @@ -26,10 +29,13 @@ #if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 #define A_TYPE float16_t +#define A_TYPE32 float #elif LOAD_VEC_A == 4 #define A_TYPE f16vec4 +#define A_TYPE32 vec4 #elif LOAD_VEC_A == 8 #define A_TYPE f16mat2x4 +#define A_TYPE32 mat2x4 #endif #endif @@ -1424,6 +1430,11 @@ float bf16_to_fp32(uint32_t u) return uintBitsToFloat(u << 16); } +vec4 bf16_to_fp32(uvec4 u) +{ + return vec4(bf16_to_fp32(u.x), bf16_to_fp32(u.y), bf16_to_fp32(u.z), bf16_to_fp32(u.w)); +} + float e8m0_to_fp32(uint8_t x) { uint32_t bits; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 613498d0d5..93cdfd09a9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -364,11 +364,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c }; // Shaders with f16 B_TYPE - string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); // bf16 { @@ -384,8 +384,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c if (!(coopmat || coopmat2)) #endif { - string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPE32", "vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc); } } @@ -408,13 +408,13 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c // don't generate f32 variants for coopmat2 if (!coopmat2) { - string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } if (tname != "f16" && tname != "f32") { - string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) From c97b5e5854b47b18a248d77edb693c63018a0865 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sun, 7 Sep 2025 12:00:49 -0500 Subject: [PATCH 06/46] vulkan: Support pad_ext (#15794) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 71 ++++++++++++++++++-- ggml/src/ggml-vulkan/vulkan-shaders/pad.comp | 27 +++++++- tests/test-backend-ops.cpp | 18 ++++- 3 files changed, 104 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index cd1c66ba7b..5185ede5d6 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -803,6 +803,57 @@ static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_ten p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize); p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize); + return p; // offsets are initialized later in ggml_vk_op +} + +struct vk_op_pad_push_constants { + uint32_t ne; + uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; + uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13; + uint32_t misalign_offsets; + + uint32_t lp0; uint32_t rp0; + uint32_t lp1; uint32_t rp1; + uint32_t lp2; uint32_t rp2; + uint32_t lp3; uint32_t rp3; +}; + +static vk_op_pad_push_constants vk_op_pad_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst) { + int64_t ne = ggml_nelements(dst); + GGML_ASSERT(ne <= (int64_t)std::numeric_limits::max()); + + vk_op_pad_push_constants p{}; + p.ne = (uint32_t)ne; + + size_t src0_tsize = ggml_type_size(src0->type); + p.ne00 = (uint32_t)src0->ne[0]; + p.ne01 = (uint32_t)src0->ne[1]; + p.ne02 = (uint32_t)src0->ne[2]; + p.ne03 = (uint32_t)src0->ne[3]; + p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize); + p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize); + p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize); + p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize); + + size_t dst_tsize = ggml_type_size(dst->type); + p.ne10 = (uint32_t)dst->ne[0]; + p.ne11 = (uint32_t)dst->ne[1]; + p.ne12 = (uint32_t)dst->ne[2]; + p.ne13 = (uint32_t)dst->ne[3]; + p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize); + p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize); + p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize); + p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize); + + p.lp0 = dst->op_params[0]; + p.rp0 = dst->op_params[1]; + p.lp1 = dst->op_params[2]; + p.rp1 = dst->op_params[3]; + p.lp2 = dst->op_params[4]; + p.rp2 = dst->op_params[5]; + p.lp3 = dst->op_params[6]; + p.rp3 = dst->op_params[7]; + return p; // fastdiv values and offsets are initialized later in ggml_vk_op } @@ -3250,7 +3301,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -7829,6 +7880,16 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk GGML_UNUSED(src2); } +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_pad_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + p.misalign_offsets = (a_offset << 16) | d_offset; + + GGML_UNUSED(src1); + GGML_UNUSED(src2); +} + template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type); @@ -8771,7 +8832,7 @@ static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, con } static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst)); + vk_op_pad_push_constants p = vk_op_pad_push_constants_init(src0, dst); ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun); } @@ -12076,10 +12137,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_ACC: case GGML_OP_CONCAT: case GGML_OP_SCALE: - return true; case GGML_OP_PAD: - return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) && - (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0); case GGML_OP_ROLL: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: @@ -12520,7 +12578,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * const float * params = (const float *)tensor->op_params; tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]); } else if (tensor->op == GGML_OP_PAD) { - tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]); + tensor_clone = ggml_pad_ext(ggml_ctx, src_clone[0], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3], + tensor->op_params[4], tensor->op_params[5], tensor->op_params[6], tensor->op_params[7]); } else if (tensor->op == GGML_OP_REPEAT) { tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor); } else if (tensor->op == GGML_OP_REPEAT_BACK) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp b/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp index 450b67fc55..0d81220c71 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp @@ -1,7 +1,25 @@ #version 450 #include "types.comp" -#include "generic_unary_head.comp" + +layout (push_constant) uniform parameter +{ + uint ne; + uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; + uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; + uint misalign_offsets; + + uint lp0; uint rp0; + uint lp1; uint rp1; + uint lp2; uint rp2; + uint lp3; uint rp3; +} p; + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_doffset() { return p.misalign_offsets & 0xFFFF; } + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; @@ -19,10 +37,13 @@ void main() { const uint i1 = (idx - i3_offset - i2_offset) / p.ne10; const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10; - const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00; + const uint src0_idx = (i3 - p.lp3)*p.nb03 + (i2 - p.lp2)*p.nb02 + (i1 - p.lp1)*p.nb01 + (i0 - p.lp0)*p.nb00; const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10; - const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03; + const bool is_src0 = i0 >= p.lp0 && i0 < p.ne10 - p.rp0 && + i1 >= p.lp1 && i1 < p.ne11 - p.rp1 && + i2 >= p.lp2 && i2 < p.ne12 - p.rp2 && + i3 >= p.lp3 && i3 < p.ne13 - p.rp3; data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f); } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index d638a96ee9..f98a1d2eb5 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4704,21 +4704,28 @@ struct test_pad_ext : public test_case { const int rp2; const int lp3; const int rp3; + const bool v; std::string vars() override { - return VARS_TO_STR10(type, ne_a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3); + return VARS_TO_STR11(type, ne_a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, v); } test_pad_ext(ggml_type type = GGML_TYPE_F32, std::array ne_a = {512, 512, 3, 1}, int lp0 = 1, int rp0 = 1, int lp1 = 1, int rp1 = 1, - int lp2 = 1, int rp2 = 1, int lp3 = 1, int rp3 = 1) - : type(type), ne_a(ne_a), lp0(lp0), rp0(rp0), lp1(lp1), rp1(rp1), lp2(lp2), rp2(rp2), lp3(lp3), rp3(rp3) {} + int lp2 = 1, int rp2 = 1, int lp3 = 1, int rp3 = 1, + bool v = false) + : type(type), ne_a(ne_a), lp0(lp0), rp0(rp0), lp1(lp1), rp1(rp1), lp2(lp2), rp2(rp2), lp3(lp3), rp3(rp3), v(v) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data()); ggml_set_name(a, "a"); + if (v) { + a = ggml_view_4d(ctx, a, (a->ne[0] + 1) / 2, (a->ne[1] + 1) / 2, (a->ne[2] + 1) / 2, (a->ne[3] + 1) / 2, a->nb[1], a->nb[2], a->nb[3], 0); + ggml_set_name(a, "view of a"); + } + ggml_tensor * out = ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3); ggml_set_name(out, "out"); @@ -6459,6 +6466,11 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_timestep_embedding()); test_cases.emplace_back(new test_leaky_relu()); + for (bool v : {false, true}) { + test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v)); + test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v)); + } + for (int hsk : { 40, 64, 80, 128, 192, 256, 576 }) { for (int hsv : { 40, 64, 80, 128, 192, 256, 512 }) { if (hsk != 192 && hsk != 576 && hsk != hsv) continue; From d36e61c580bf7fc7879c443c542312a42b718e11 Mon Sep 17 00:00:00 2001 From: Aaron Teo Date: Mon, 8 Sep 2025 02:18:28 +0800 Subject: [PATCH 07/46] ggml-cpu: clean up s390x SIMD (#15855) * ggml-cpu: clean up s390x simd Signed-off-by: Aaron Teo (cherry picked from commit 0da4b6aa07d96b758812d17b2c82267632fa4ba5) Signed-off-by: Aaron Teo * ggml-cpu: fix hsum data types Signed-off-by: Aaron Teo --------- Signed-off-by: Aaron Teo --- ggml/src/ggml-cpu/arch/s390/quants.c | 116 +++++++++++++-------------- ggml/src/ggml-cpu/ggml-cpu-impl.h | 7 +- 2 files changed, 63 insertions(+), 60 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/s390/quants.c b/ggml/src/ggml-cpu/arch/s390/quants.c index 1c8176fb4d..dc1bba3a3e 100644 --- a/ggml/src/ggml-cpu/arch/s390/quants.c +++ b/ggml/src/ggml-cpu/arch/s390/quants.c @@ -53,9 +53,9 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i #if defined(__VXE__) || defined(__VXE2__) for (int i = 0; i < nb; i++) { - __vector float srcv [8]; - __vector float asrcv[8]; - __vector float amaxv[8]; + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); @@ -74,8 +74,8 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i y[i].d = GGML_CPU_FP32_TO_FP16(d); for (int j = 0; j < 8; j++) { - const __vector float v = vec_mul(srcv[j], vec_splats(id)); - const __vector int32_t vi = vec_signed(v); + const float32x4_t v = vec_mul(srcv[j], vec_splats(id)); + const int32x4_t vi = vec_signed(v); y[i].qs[4*j + 0] = vec_extract(vi, 0); y[i].qs[4*j + 1] = vec_extract(vi, 1); @@ -98,9 +98,9 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i #if defined(__VXE__) || defined(__VXE2__) for (int i = 0; i < nb; i++) { - __vector float srcv [8]; - __vector float asrcv[8]; - __vector float amaxv[8]; + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); @@ -118,11 +118,11 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i y[i].d = GGML_CPU_FP32_TO_FP16(d); - __vector int32_t acc = vec_splats(0); + int32x4_t acc = vec_splats(0); for (int j = 0; j < 8; j++) { - const __vector float v = vec_mul(srcv[j], vec_splats(id)); - const __vector int32_t vi = vec_signed(v); + const float32x4_t v = vec_mul(srcv[j], vec_splats(id)); + const int32x4_t vi = vec_signed(v); y[i].qs[4*j + 0] = vec_extract(vi, 0); y[i].qs[4*j + 1] = vec_extract(vi, 1); @@ -162,37 +162,36 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi float sumf = 0; #if defined(__VXE__) || defined(__VXE2__) - __vector float acc = vec_splats(0.0f); + float32x4_t acc = vec_splats(0.0f); - const __vector uint8_t v_m = vec_splats((const uint8_t)0x0F); - const __vector int8_t v_s = vec_splats( (const int8_t)0x08); + const uint8x16_t v_m = vec_splats((const uint8_t)0x0F); + const int8x16_t v_s = vec_splats( (const int8_t)0x08); for (; ib < nb; ++ib) { - const __vector uint8_t v_x = vec_xl(0, x[ib].qs); - const __vector int8_t v_xl = (const __vector int8_t)(v_x & v_m); - const __vector int8_t v_xh = (const __vector int8_t)(v_x >> 4); + const uint8x16_t v_x = vec_xl(0, x[ib].qs); + const int8x16_t v_xl = (const int8x16_t)(v_x & v_m); + const int8x16_t v_xh = (const int8x16_t)(v_x >> 4); - const __vector int8_t v_xls = vec_sub(v_xl, v_s); - const __vector int8_t v_xhs = vec_sub(v_xh, v_s); + const int8x16_t v_xls = vec_sub(v_xl, v_s); + const int8x16_t v_xhs = vec_sub(v_xh, v_s); - const __vector int8_t v_yl = vec_xl(0 , y[ib].qs); - const __vector int8_t v_yh = vec_xl(QK8_0/2, y[ib].qs); + const int8x16_t v_yl = vec_xl(0 , y[ib].qs); + const int8x16_t v_yh = vec_xl(QK8_0/2, y[ib].qs); - const __vector int16_t v_xylso = vec_mulo(v_xls, v_yl); - const __vector int16_t v_xylse = vec_mule(v_xls, v_yl); - const __vector int16_t v_xyhso = vec_mulo(v_xhs, v_yh); - const __vector int16_t v_xyhse = vec_mule(v_xhs, v_yh); + const int16x8_t v_xylso = vec_mulo(v_xls, v_yl); + const int16x8_t v_xylse = vec_mule(v_xls, v_yl); + const int16x8_t v_xyhso = vec_mulo(v_xhs, v_yh); + const int16x8_t v_xyhse = vec_mule(v_xhs, v_yh); - __vector int16_t v_xy_ = v_xylso + v_xylse + v_xyhso + v_xyhse; v_xy_ += vec_reve(v_xy_); + int16x8_t v_xy_ = v_xylso + v_xylse + v_xyhso + v_xyhse; v_xy_ += vec_reve(v_xy_); - const __vector float v_xy = vec_float(vec_unpackh(v_xy_)); - const __vector float v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d)); + const float32x4_t v_xy = vec_float(vec_unpackh(v_xy_)); + const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d)); acc = vec_madd(v_xy, v_d, acc); } - sumf = acc[0] + acc[1] + acc[2] + acc[3]; - + sumf = vec_hsum_f32x4(acc); *s = sumf; #else UNUSED(nb); @@ -249,8 +248,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi acc = vec_madd(v_xy, v_d, acc); } - sumf = acc[0] + acc[1] + acc[2] + acc[3] + summs; - + sumf = vec_hsum_f32x4(acc) + summs; *s = sumf; #else UNUSED(nb); @@ -351,7 +349,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi v_sum1 = vec_madd(v_xy1f, v_d1, v_sum1); } - sumf += vec_hsum(v_sum0) + vec_hsum(v_sum1); + sumf += vec_hsum_f32x4(v_sum0) + vec_hsum_f32x4(v_sum1); #pragma GCC unroll 4 for (; ib < nb; ++ib) { @@ -390,7 +388,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d)); const float32x4_t v_acc = vec_madd(v_xyf, v_d, vec_splats(0.0f)); - sumf += vec_hsum(v_acc); + sumf += vec_hsum_f32x4(v_acc); } *s = sumf; @@ -502,7 +500,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi v_sum1 = vec_madd(v_xy1f, v_d1, v_sum1); } - sumf += vec_hsum(v_sum0) + vec_hsum(v_sum1) + summs0 + summs1; + sumf += vec_hsum_f32x4(v_sum0) + vec_hsum_f32x4(v_sum1) + summs0 + summs1; #pragma GCC unroll 4 for (; ib < nb; ++ib) { @@ -543,7 +541,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d)); const float32x4_t v_acc = vec_madd(v_xyf, v_d, v_acc); - sumf += vec_hsum(v_acc) + summs; + sumf += vec_hsum_f32x4(v_acc) + summs; } *s = sumf; @@ -575,7 +573,7 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi float sumf = 0; #if defined(__VXE__) || defined(__VXE2__) - __vector float acc = vec_splats(0.0f); + float32x4_t acc = vec_splats(0.0f); #pragma GCC unroll 8 for (; ib < nb; ++ib) { @@ -594,7 +592,7 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi acc = vec_madd(v_xy, v_d, acc); } - sumf = acc[0] + acc[1] + acc[2] + acc[3]; + sumf = vec_hsum_f32x4(acc); *s = sumf; #else @@ -718,10 +716,10 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi isum2 = ggml_vec_dot(v_z, q3bytes[2], q8bytes[6]); isum3 = ggml_vec_dot(v_z, q3bytes[3], q8bytes[7]); - isum += (isum0[0] + isum0[1] + isum0[2] + isum0[3]) * scale[0]; - isum += (isum1[0] + isum1[1] + isum1[2] + isum1[3]) * scale[1]; - isum += (isum2[0] + isum2[1] + isum2[2] + isum2[3]) * scale[2]; - isum += (isum3[0] + isum3[1] + isum3[2] + isum3[3]) * scale[3]; + isum += vec_hsum_i32x4(isum0) * scale[0]; + isum += vec_hsum_i32x4(isum1) * scale[1]; + isum += vec_hsum_i32x4(isum2) * scale[2]; + isum += vec_hsum_i32x4(isum3) * scale[3]; scale += 4; @@ -819,7 +817,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi v_xl[1] = (int8x16_t)vec_and(v_x[1], v_lm); const int32x4_t p1 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]); - sumi1 += (p1[0] + p1[1] + p1[2] + p1[3]) * scales[2*j+0]; + sumi1 += vec_hsum_i32x4(p1) * scales[2*j+0]; v_y[0] = vec_xl(0 , y0); v_y[1] = vec_xl(16, y0); @@ -829,7 +827,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi v_xl[1] = (int8x16_t)vec_sr(v_x[1], 4); const int32x4_t p2 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]); - sumi2 += (p2[0] + p2[1] + p2[2] + p2[3]) * scales[2*j+1]; + sumi2 += vec_hsum_i32x4(p2) * scales[2*j+1]; } sumf += d * (sumi1 + sumi2); @@ -911,7 +909,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int32x4_t v_minsho = vec_mulo(v_ysums, v_minsh); const int32x4_t v_minshe = vec_mule(v_ysums, v_minsh); const int32x4_t v_mins = vec_add(v_minsho, v_minshe); - const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]; + const int32_t mins = vec_hsum_i32x4(v_mins); const uint8_t * scales = (const uint8_t *)utmp; const uint8_t * GGML_RESTRICT x0l = x[i].qs; @@ -948,8 +946,8 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi int32x4_t sumi0 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[0], v_y[0]), q5b[1], v_y[1]); int32x4_t sumi1 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[2], v_y[2]), q5b[3], v_y[3]); - sumi += (sumi0[0] + sumi0[1] + sumi0[2] + sumi0[3]) * *scales++; - sumi += (sumi1[0] + sumi1[1] + sumi1[2] + sumi1[3]) * *scales++; + sumi += vec_hsum_i32x4(sumi0) * *scales++; + sumi += vec_hsum_i32x4(sumi1) * *scales++; } sumf += d * sumi - dmin * mins; @@ -1020,7 +1018,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int32x4_t v_minshe = vec_mule(v_ysumsh, v_scaleh); const int32x4_t v_mins = v_minslo + v_minsle + v_minsho + v_minshe; - const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]; + const int32_t mins = vec_hsum_i32x4(v_mins); int32_t isum = 0; for (int j = 0; j < QK_K/128; ++j) { @@ -1060,10 +1058,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi int32x4_t summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]); int32x4_t summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]); - isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] + - (summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] + - (summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] + - (summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3]; + isum += vec_hsum_i32x4(summs0) * scale[0] + + vec_hsum_i32x4(summs1) * scale[1] + + vec_hsum_i32x4(summs2) * scale[2] + + vec_hsum_i32x4(summs3) * scale[3]; scale += 4; @@ -1094,10 +1092,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]); summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]); - isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] + - (summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] + - (summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] + - (summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3]; + isum += vec_hsum_i32x4(summs0) * scale[0] + + vec_hsum_i32x4(summs1) * scale[1] + + vec_hsum_i32x4(summs2) * scale[2] + + vec_hsum_i32x4(summs3) * scale[3]; scale += 4; } @@ -1285,7 +1283,7 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v const int8x16_t v_yh = vec_xl(QK8_0/2, y0->qs); const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh); - sumf += GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d) * (v_xy[0] + v_xy[1] + v_xy[2] + v_xy[3]); + sumf += GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d) * vec_hsum_i32x4(v_xy); } *s = sumf; @@ -1354,8 +1352,8 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v h >>= 4; - sumi1 += (vsumi0[0] + vsumi0[1] + vsumi0[2] + vsumi0[3]) * ls1; - sumi2 += (vsumi1[0] + vsumi1[1] + vsumi1[2] + vsumi1[3]) * ls2; + sumi1 += vec_hsum_i32x4(vsumi0) * ls1; + sumi2 += vec_hsum_i32x4(vsumi1) * ls2; } sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2); diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h index cd055e75cb..799e2b1187 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-impl.h +++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h @@ -483,11 +483,16 @@ inline static int16x8_t vec_padd_s16(int16x8_t a, int16x8_t b) { /** * @see https://github.com/ggml-org/llama.cpp/pull/14037 */ -inline static float vec_hsum(float32x4_t v) { +inline static float vec_hsum_f32x4(float32x4_t v) { float32x4_t v_temp = v + vec_reve(v); return v_temp[0] + v_temp[1]; } +inline static int32_t vec_hsum_i32x4(int32x4_t v) { + int32x4_t v_temp = v + vec_reve(v); + return v_temp[0] + v_temp[1]; +} + inline static int32x4_t ggml_vec_dot(int32x4_t acc, int8x16_t a, int8x16_t b) { const int16x8_t p = vec_mule(a, b) + vec_mulo(a, b); return acc + (vec_unpackh(p) + vec_unpackl(p)); From 3976dfbe00f02a62c0deca32c46138e4f0ca81d8 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sun, 7 Sep 2025 13:50:26 -0500 Subject: [PATCH 08/46] vulkan: support im2col_3d (#15795) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 157 +++++++++++++++++- .../ggml-vulkan/vulkan-shaders/im2col_3d.comp | 112 +++++++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 4 + tests/test-backend-ops.cpp | 24 ++- 4 files changed, 290 insertions(+), 7 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 5185ede5d6..36951bceb8 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -554,6 +554,7 @@ struct vk_device_struct { vk_pipeline pipeline_argmax_f32; vk_pipeline pipeline_count_equal_i32; vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; + vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16; vk_pipeline pipeline_timestep_embedding_f32; vk_pipeline pipeline_conv_transpose_1d_f32; vk_pipeline pipeline_pool2d_f32; @@ -982,6 +983,37 @@ struct vk_op_im2col_push_constants { int32_t d0; int32_t d1; }; +struct vk_op_im2col_3d_push_constants { + uint32_t nb10; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t s0; + uint32_t s1; + uint32_t s2; + uint32_t p0; + uint32_t p1; + uint32_t p2; + uint32_t d0; + uint32_t d1; + uint32_t d2; + uint32_t IW; + uint32_t IH; + uint32_t ID; + uint32_t IC; + uint32_t KW; + uint32_t OH; + uint32_t KD_KH_KW; + uint32_t KH_KW; + uint32_t IC_KD_KH_KW; + uint32_t N_OD_OH; + uint32_t OD_OH; + uint32_t OD_OH_OW_IC_KD_KH_KW; + uint32_t OH_OW_IC_KD_KH_KW; + uint32_t OW_IC_KD_KH_KW; + uint32_t misalign_offsets; +}; + struct vk_op_timestep_embedding_push_constants { uint32_t nb1; uint32_t dim; @@ -3380,10 +3412,13 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32_len, im2col_3d_f32_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); if (device->float_controls_rte_fp16) { ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte_len, im2col_3d_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); } else { ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_len, im2col_3d_f32_f16_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); } ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1); @@ -7717,6 +7752,14 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_im2col_f32_f16; } return nullptr; + case GGML_OP_IM2COL_3D: + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_im2col_3d_f32; + } + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_im2col_3d_f32_f16; + } + return nullptr; case GGML_OP_TIMESTEP_EMBEDDING: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_timestep_embedding_f32; @@ -7832,6 +7875,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { case GGML_OP_RMS_NORM: case GGML_OP_CONV_2D_DW: case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: case GGML_OP_SET_ROWS: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: @@ -7890,6 +7934,16 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk GGML_UNUSED(src2); } +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + p.misalign_offsets = (a_offset << 16) | d_offset; + + GGML_UNUSED(src0); + GGML_UNUSED(src2); +} + template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type); @@ -8130,6 +8184,26 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements = { OW * KW * KH, OH, batch * IC }; } break; + case GGML_OP_IM2COL_3D: + { + const uint32_t IC = ((const uint32_t *)(dst->op_params))[9]; + + const uint32_t N = ne13 / IC; + + const uint32_t KD = ne02; + const uint32_t KH = ne01; + const uint32_t KW = ne00; + + const uint32_t OD = ned3 / N; + const uint32_t OH = ned2; + const uint32_t OW = ned1; + + const uint32_t IC_KD_KH_KW = IC*KD*KH*KW; + const uint32_t N_OD_OH = N*OD*OH; + + elements = { IC_KD_KH_KW, OW, N_OD_OH }; + elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); + } break; case GGML_OP_TIMESTEP_EMBEDDING: { const uint32_t dim = dst->op_params[0]; @@ -8286,7 +8360,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); - } else if (op == GGML_OP_IM2COL) { + } else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) { // im2col uses only src1 and dst buffers ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (op == GGML_OP_COUNT_EQUAL) { @@ -9147,6 +9221,66 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co }, dryrun); } +static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + GGML_TENSOR_BINARY_OP_LOCALS + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t s2 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[3]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[4]; + const int32_t p2 = ((const int32_t *)(dst->op_params))[5]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[6]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[7]; + const int32_t d2 = ((const int32_t *)(dst->op_params))[8]; + const int32_t IC = ((const int32_t *)(dst->op_params))[9]; + + const int64_t N = ne13 / IC; + const int64_t ID = ne12; + const int64_t IH = ne11; + const int64_t IW = ne10; + + const int64_t KD = ne02; + const int64_t KH = ne01; + const int64_t KW = ne00; + + const int64_t OD = ne3 / N; + const int64_t OH = ne2; + const int64_t OW = ne1; + + vk_op_im2col_3d_push_constants pc {}; + + pc.nb10 = nb10 / ggml_type_size(src1->type); + pc.nb11 = nb11 / ggml_type_size(src1->type); + pc.nb12 = nb12 / ggml_type_size(src1->type); + pc.nb13 = nb13 / ggml_type_size(src1->type); + pc.s0 = s0; + pc.s1 = s1; + pc.s2 = s2; + pc.p0 = p0; + pc.p1 = p1; + pc.p2 = p2; + pc.d0 = d0; + pc.d1 = d1; + pc.d2 = d2; + pc.IW = IW; + pc.IH = IH; + pc.ID = ID; + pc.IC = IC; + pc.KW = KW; + pc.OH = OH; + pc.KD_KH_KW = KD*KH*KW; + pc.KH_KW = KH*KW; + pc.IC_KD_KH_KW = IC*KD*KH*KW; + pc.N_OD_OH = N*OD*OH; + pc.OD_OH = OD*OH; + pc.OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW; + pc.OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW; + pc.OW_IC_KD_KH_KW = OW*IC*KD*KH*KW; + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL_3D, std::move(pc), dryrun); +} + static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { const uint32_t dim = dst->op_params[0]; const uint32_t max_period = dst->op_params[1]; @@ -10352,6 +10486,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_ARGMAX: case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: @@ -10422,6 +10557,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_ARGMAX: case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: @@ -10717,6 +10853,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_IM2COL: ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun); + break; + case GGML_OP_IM2COL_3D: + ggml_vk_im2col_3d(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_TIMESTEP_EMBEDDING: ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun); @@ -10868,6 +11008,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_OP_ARGMAX: case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: @@ -12150,6 +12291,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_ARGMAX: case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_CONV_2D_DW: case GGML_OP_POOL_2D: @@ -12725,6 +12867,19 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * const bool is_2D = tensor->op_params[6] == 1; tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type); + } else if (tensor->op == GGML_OP_IM2COL_3D) { + const int32_t s0 = tensor->op_params[0]; + const int32_t s1 = tensor->op_params[1]; + const int32_t s1 = tensor->op_params[2]; + const int32_t p0 = tensor->op_params[3]; + const int32_t p1 = tensor->op_params[4]; + const int32_t p1 = tensor->op_params[5]; + const int32_t d0 = tensor->op_params[6]; + const int32_t d1 = tensor->op_params[7]; + const int32_t d1 = tensor->op_params[8]; + const int32_t IC = tensor->op_params[9]; + + tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, tensor->type); } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) { const int32_t dim = tensor->op_params[0]; const int32_t max_period = tensor->op_params[1]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp new file mode 100644 index 0000000000..3b010bdeb5 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp @@ -0,0 +1,112 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "rte.comp" + +layout (push_constant) uniform parameter +{ + uint32_t nb10; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t s0; + uint32_t s1; + uint32_t s2; + uint32_t p0; + uint32_t p1; + uint32_t p2; + uint32_t d0; + uint32_t d1; + uint32_t d2; + uint32_t IW; + uint32_t IH; + uint32_t ID; + uint32_t IC; + uint32_t KW; + uint32_t OH; + uint32_t KD_KH_KW; + uint32_t KH_KW; + uint32_t IC_KD_KH_KW; + uint32_t N_OD_OH; + uint32_t OD_OH; + uint32_t OD_OH_OW_IC_KD_KH_KW; + uint32_t OH_OW_IC_KD_KH_KW; + uint32_t OW_IC_KD_KH_KW; + uint32_t misalign_offsets; +} p; + +#include "types.comp" + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_doffset() { return p.misalign_offsets & 0xFFFF; } + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint32_t i = gl_GlobalInvocationID.x; + + uint32_t nb10 = p.nb10; + uint32_t nb11 = p.nb11; + uint32_t nb12 = p.nb12; + uint32_t nb13 = p.nb13; + uint32_t s0 = p.s0; + uint32_t s1 = p.s1; + uint32_t s2 = p.s2; + uint32_t p0 = p.p0; + uint32_t p1 = p.p1; + uint32_t p2 = p.p2; + uint32_t d0 = p.d0; + uint32_t d1 = p.d1; + uint32_t d2 = p.d2; + uint32_t IW = p.IW; + uint32_t IH = p.IH; + uint32_t ID = p.ID; + uint32_t IC = p.IC; + uint32_t KW = p.KW; + uint32_t OH = p.OH; + uint32_t KD_KH_KW = p.KD_KH_KW; + uint32_t KH_KW = p.KH_KW; + uint32_t IC_KD_KH_KW = p.IC_KD_KH_KW; + uint32_t N_OD_OH = p.N_OD_OH; + uint32_t OD_OH = p.OD_OH; + uint32_t OD_OH_OW_IC_KD_KH_KW = p.OD_OH_OW_IC_KD_KH_KW; + uint32_t OH_OW_IC_KD_KH_KW = p.OH_OW_IC_KD_KH_KW; + uint32_t OW_IC_KD_KH_KW = p.OW_IC_KD_KH_KW; + + if (i >= IC_KD_KH_KW) { + return; + } + + const uint32_t iic = i / KD_KH_KW; + const uint32_t ikd = (i - iic * KD_KH_KW) / KH_KW; + const uint32_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW; + const uint32_t ikw = i % KW; + + const uint32_t iow = gl_GlobalInvocationID.y; + for (uint32_t iz = gl_GlobalInvocationID.z; iz < N_OD_OH; iz += gl_NumWorkGroups.z) { + const uint32_t in_ = iz / OD_OH; + const uint32_t iod = (iz - in_*OD_OH) / OH; + const uint32_t ioh = iz % OH; + + const uint32_t iiw = iow * s0 + ikw * d0 - p0; + const uint32_t iih = ioh * s1 + ikh * d1 - p1; + const uint32_t iid = iod * s2 + ikd * d2 - p2; + + const uint32_t offset_dst = in_*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; + + if (iih >= IH || iiw >= IW || iid >= ID) { + data_d[offset_dst + get_doffset()] = D_TYPE(0.0f); + } else { + const uint32_t offset_src = (in_*IC + iic)*nb13 + iid*nb12 + iih*nb11 + iiw*nb10; + data_d[offset_dst + get_doffset()] = D_TYPE(data_a[offset_src + get_aoffset()]); + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 93cdfd09a9..27394c170f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -713,6 +713,10 @@ void process_shaders() { string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}})); string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}})); + string_to_spv("im2col_3d_f32", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("im2col_3d_f32_f16", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}})); + string_to_spv("im2col_3d_f32_f16_rte", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}})); + string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index f98a1d2eb5..1f04f3d992 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -300,6 +300,7 @@ static std::string var_to_str(ggml_scale_mode mode) { #define VARS_TO_STR13(a, b, c, d, e, f, g, h, i, j, k, l, m) VAR_TO_STR(a) + "," + VARS_TO_STR12(b, c, d, e, f, g, h, i, j, k, l, m) #define VARS_TO_STR14(a, b, c, d, e, f, g, h, i, j, k, l, m, n) VAR_TO_STR(a) + "," + VARS_TO_STR13(b, c, d, e, f, g, h, i, j, k, l, m, n) #define VARS_TO_STR15(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o) VAR_TO_STR(a) + "," + VARS_TO_STR14(b, c, d, e, f, g, h, i, j, k, l, m, n, o) +#define VARS_TO_STR16(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p) VAR_TO_STR(a) + "," + VARS_TO_STR15(b, c, d, e, f, g, h, i, j, k, l, m, n, o, p) #ifdef GGML_USE_SYCL static bool inline _isinf(float f) { @@ -4047,9 +4048,10 @@ struct test_im2col_3d : public test_case { const int d2; const int64_t IC; + const bool v; std::string vars() override { - return VARS_TO_STR15(type_input, type_kernel, dst_type, ne_input, ne_kernel, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2); + return VARS_TO_STR16(type_input, type_kernel, dst_type, ne_input, ne_kernel, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, v); } test_im2col_3d(ggml_type type_input = GGML_TYPE_F32, ggml_type type_kernel = GGML_TYPE_F16, ggml_type dst_type = GGML_TYPE_F32, @@ -4058,14 +4060,20 @@ struct test_im2col_3d : public test_case { int64_t IC = 3, int s0 = 1, int s1 = 1, int s2 = 1, int p0 = 1, int p1 = 1, int p2 = 1, - int d0 = 1, int d1 = 1, int d2 = 1) - : type_input(type_input), type_kernel(type_kernel), dst_type(dst_type), ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), s1(s1), s2(s2), p0(p0), p1(p1), p2(p2), d0(d0), d1(d1), d2(d2), IC(IC) {} + int d0 = 1, int d1 = 1, int d2 = 1, + bool v = false) + : type_input(type_input), type_kernel(type_kernel), dst_type(dst_type), ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), s1(s1), s2(s2), p0(p0), p1(p1), p2(p2), d0(d0), d1(d1), d2(d2), IC(IC), v(v) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data()); ggml_set_param(input); ggml_set_name(input, "input"); + if (v) { + input = ggml_view_4d(ctx, input, ne_input[0] - 2, ne_input[1] - 2, ne_input[2] - 2, ne_input[3] - 2, input->nb[1], input->nb[2], input->nb[3], 0); + ggml_set_name(input, "view_of_input"); + } + ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data()); ggml_set_name(kernel, "kernel"); @@ -5729,9 +5737,13 @@ static std::vector> make_test_cases_eval() { for (int d0 : {1, 3}) { for (int d1 : {1, 3}) { for (int d2 : {1, 3}) { - test_cases.emplace_back(new test_im2col_3d( - GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 20, 10, 3}, {3, 3, 3, 3}, - 3, s0, s1, s2, p0, p1, p2, d0, d1, d2)); + for (int IC : {1, 3}) { + for (bool v : {false, true}) { + test_cases.emplace_back(new test_im2col_3d( + GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 20, 10, 3}, {3, 3, 3, 3}, + IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, v)); + } + } } } } From 85ca66a74676e6d5df4433016488e039a4b464ae Mon Sep 17 00:00:00 2001 From: Chenguang Li <757486878@qq.com> Date: Mon, 8 Sep 2025 10:03:29 +0800 Subject: [PATCH 09/46] CANN: Stream sync between devices for acl_graph (#15809) * CANN: Switch to stream synchronization Switch to stream synchronization because events are not effective. Co-authored-by: hipudding * CANN: add Comments --------- Co-authored-by: hipudding --- ggml/src/ggml-cann/ggml-cann.cpp | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 756ad8dfad..2f9f373f54 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2092,16 +2092,17 @@ static bool ggml_backend_cann_cpy_tensor_async( ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_MEMCPY_DEVICE_TO_DEVICE, cann_ctx_src->stream())); - // record event on src stream after the copy - if (!cann_ctx_src->copy_event) { - ACL_CHECK(aclrtCreateEventWithFlag(&cann_ctx_src->copy_event, ACL_EVENT_SYNC)); - } - ACL_CHECK(aclrtRecordEvent(cann_ctx_src->copy_event, cann_ctx_src->stream())); + // TODO: this event is not effective with acl graph mode, change to use aclrtSynchronizeStream + // if (!cann_ctx_src->copy_event) { + // ACL_CHECK(aclrtCreateEventWithFlag(&cann_ctx_src->copy_event, ACL_EVENT_SYNC)); + // } + // ACL_CHECK(aclrtRecordEvent(cann_ctx_src->copy_event, cann_ctx_src->stream())); - // wait on dst stream for the copy to complete - ggml_cann_set_device(cann_ctx_dst->device); - ACL_CHECK(aclrtStreamWaitEvent(cann_ctx_dst->stream(), cann_ctx_src->copy_event)); + // // wait on dst stream for the copy to complete + // ggml_cann_set_device(cann_ctx_dst->device); + // ACL_CHECK(aclrtStreamWaitEvent(cann_ctx_dst->stream(), cann_ctx_src->copy_event)); + ACL_CHECK(aclrtSynchronizeStream(cann_ctx_src->stream())); } else { // src and dst are on the same backend ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, From d413dca00360a7e4cb71441dacecfa32556fcc31 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sun, 7 Sep 2025 23:23:41 -0500 Subject: [PATCH 10/46] tests: large sizes for get_rows (#15687) --- tests/test-backend-ops.cpp | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 1f04f3d992..f4cd8f65cf 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1957,24 +1957,25 @@ struct test_get_rows : public test_case { const int n; // cols const int m; // rows const int r; // rows to get - const int b; // batch size + const int be1; // batch size + const int be2; // batch size const bool v; // view (non-contiguous src1) std::string vars() override { - return VARS_TO_STR6(type, n, m, r, b, v); + return VARS_TO_STR7(type, n, m, r, be1, be2, v); } - test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1, bool v = false) - : type(type), n(n), m(m), r(r), b(b), v(v) {} + test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int be1 = 1, int be2 = 1, bool v = false) + : type(type), n(n), m(m), r(r), be1(be1), be2(be2), v(v) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * in = ggml_new_tensor_3d(ctx, type, n, m, b); + ggml_tensor * in = ggml_new_tensor_4d(ctx, type, n, m, be1, be2); ggml_set_name(in, "in"); - ggml_tensor * rows = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, r, b); + ggml_tensor * rows = ggml_new_tensor_3d(ctx, GGML_TYPE_I32, r, be1, be2); ggml_set_name(rows, "rows"); if (v) { - rows = ggml_view_2d(ctx, rows, r/2, b, rows->nb[1], 0); + rows = ggml_view_3d(ctx, rows, r/2, be1, be2, rows->nb[1], rows->nb[2], 0); ggml_set_name(rows, "view_of_rows"); } @@ -1995,11 +1996,11 @@ struct test_get_rows : public test_case { if (t->type == GGML_TYPE_I32) { if (ggml_is_view_op(t->op)) { continue; } // rows - std::vector data(r*b); - for (int i = 0; i < r*b; i++) { + std::vector data(r*be1*be2); + for (int i = 0; i < r*be1*be2; i++) { data[i] = rand() % m; } - ggml_backend_tensor_set(t, data.data(), 0, r * b * sizeof(int)); + ggml_backend_tensor_set(t, data.data(), 0, r * be1 * be2 * sizeof(int)); } else { init_tensor_uniform(t); } @@ -5620,17 +5621,23 @@ static std::vector> make_test_cases_eval() { } } - test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false)); + for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_Q4_0}) { + test_cases.emplace_back(new test_get_rows(type, 300*256, 5, 4, 1, 2, false)); + test_cases.emplace_back(new test_get_rows(type, 256, 80000, 70000, 2, 1, false)); + test_cases.emplace_back(new test_get_rows(type, 256, 5, 4, 700, 100, false)); + } + + test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, 1, false)); for (ggml_type type : all_types) { for (int b : {1, 7}) { for (bool v : {false, true}) { - test_cases.emplace_back(new test_get_rows(type, 256, 5, 4, b, v)); + test_cases.emplace_back(new test_get_rows(type, 256, 5, 4, b, 1, v)); } } } for (int b : {1, 7}) { for (bool v : {false, true}) { - test_cases.emplace_back(new test_get_rows(GGML_TYPE_I32, 256, 5, 4, b, v)); + test_cases.emplace_back(new test_get_rows(GGML_TYPE_I32, 256, 5, 4, b, 1, v)); } } From cf0e3ba1500bd23635b444f64ba23cbdb56c92ef Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 8 Sep 2025 10:25:33 +0300 Subject: [PATCH 11/46] model : avoid ggml_cont_3d for fused QKV weights (#15662) * model : avoid ggml_cont_3d for fused QKV weights ggml-ci * kv-cache : make cpy_k and cpy_v implementation more readable ggml-ci * cont : add comments ggml-ci * cont : minor fix [no ci] * cont : one more fix * cont : clarity ggml-ci * kv-cache : require contiguous heads of k_cur and v_cur ggml-ci --- src/llama-kv-cache.cpp | 72 +++++++++++++++++------ src/llama-kv-cache.h | 8 +++ src/llama-model.cpp | 128 ++++++++++++----------------------------- 3 files changed, 100 insertions(+), 108 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index ae35f74201..885be072a7 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1018,16 +1018,33 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm const int32_t ikv = map_layer_ids.at(il); - auto * k = layers[ikv].k; + ggml_tensor * k = layers[ikv].k; - const int64_t n_tokens = k_cur->ne[2]; + const int64_t n_embd_head = k_cur->ne[0]; + const int64_t n_head = k_cur->ne[1]; + const int64_t n_tokens = k_cur->ne[2]; - k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens); + const int64_t n_embd_gqa = n_embd_head*n_head; - if (k->ne[2] > 1) { - k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]); + // we can merge dims 0 and 1 + // TODO: add ggml helper function for this? + GGML_ASSERT(ggml_row_size(k_cur->type, n_embd_head) == k_cur->nb[1]); + + k_cur = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_tokens, k_cur->nb[2], 0); + + const int64_t n_stream = k->ne[2]; + + if (n_stream > 1) { + const int64_t kv_size = get_size(); + + assert(n_embd_gqa == k->ne[0]); + assert(kv_size == k->ne[1]); + + // merge the buffer across all streams because the idxs are global + k = ggml_reshape_2d(ctx, k, n_embd_gqa, kv_size*n_stream); } + // store the current K values into the cache return ggml_set_rows(ctx, k, k_cur, k_idxs); } @@ -1038,28 +1055,51 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm auto * v = layers[ikv].v; - const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1]; - const int64_t n_tokens = v_cur->ne[2]; + const int64_t n_embd_head = v_cur->ne[0]; + const int64_t n_head = v_cur->ne[1]; + const int64_t n_tokens = v_cur->ne[2]; - v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens); + const int64_t n_embd_gqa = n_embd_head*n_head; + // we can merge dims 0 and 1 + GGML_ASSERT(ggml_row_size(v_cur->type, n_embd_head) == v_cur->nb[1]); + + const int64_t n_stream = v->ne[2]; + + // take this branch when FA is enabled (the V cache is not transposed) if (!v_trans) { - if (v->ne[2] > 1) { - v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]); + v_cur = ggml_view_2d(ctx, v_cur, n_embd_gqa, n_tokens, v_cur->nb[2], 0); + + if (n_stream > 1) { + const int64_t kv_size = get_size(); + + assert(n_embd_gqa == v->ne[0]); + assert(kv_size == v->ne[1]); + + // merge the buffer across all streams because the idxs are global + v = ggml_reshape_2d(ctx, v, n_embd_gqa, kv_size*n_stream); } return ggml_set_rows(ctx, v, v_cur, v_idxs); } - // [TAG_V_CACHE_VARIABLE] - if (n_embd_v_gqa < v->ne[0]) { - v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0); + if (ggml_row_size(v_cur->type, n_embd_gqa) == v_cur->nb[2]) { + // we can merge dims 0, 1 and 2 + v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_tokens); + } else { + // otherwise -> make a copy to get contiguous data + v_cur = ggml_cont_2d (ctx, v_cur, n_embd_gqa, n_tokens); } - // the row becomes a single element - ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]); + // [TAG_V_CACHE_VARIABLE] + if (n_embd_gqa < v->ne[0]) { + v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_gqa, 0, 0, 0); + } - v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]); + // in this branch the v_idxs are constructed in such a way that each row is a single head element + ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, ggml_nelements(v)); + + v_cur = ggml_reshape_2d(ctx, v_cur, 1, ggml_nelements(v_cur)); return ggml_set_rows(ctx, v_view, v_cur, v_idxs); } diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index b545bf5b9c..30de013f5f 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -317,9 +317,17 @@ public: ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; // store k_cur and v_cur in the cache based on the provided head location + // note: the heads in k_cur and v_cur should be layed out contiguously in memory + // - k_cur [n_embd_head_k, n_head_k, n_tokens] + // - k_idxs [n_tokens] + // - v_cur [n_embd_head_v, n_head_v, n_tokens] + // - v_idxs [n_tokens] or [n_tokens*n_embd_v_gqa] depending if V cache is transposed ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const; ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const; + // create destination indices for each head of the current batch for where it would be written in the KV cache + // the indices address the global KV cache (not per stream) - this is not relevant for the user of this API, but + // helps understand the implementation logic of cpy_k and cpy_v ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 1813f06d7b..b9e4634a70 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -6927,9 +6927,7 @@ struct llm_build_falcon : public llm_graph_context { ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); // using mode = 2 for neox mode Qcur = ggml_rope_ext( @@ -7207,9 +7205,7 @@ struct llm_build_dbrx : public llm_graph_context { Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -7329,13 +7325,9 @@ struct llm_build_starcoder : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - - Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); @@ -7551,14 +7543,16 @@ struct llm_build_bert : public llm_graph_context { cb(cur, "bqkv", il); } - Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)); - Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); } else { Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq); Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk); Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } @@ -7569,8 +7563,6 @@ struct llm_build_bert : public llm_graph_context { LLM_NORM, il); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - } else { - Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); } if (model.layers[il].attn_k_norm) { @@ -7580,8 +7572,6 @@ struct llm_build_bert : public llm_graph_context { LLM_NORM, il); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - } else { - Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); } // RoPE @@ -7727,9 +7717,7 @@ struct llm_build_neo_bert : public llm_graph_context { Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); // RoPE Qcur = ggml_rope_ext( @@ -7836,13 +7824,9 @@ struct llm_build_bloom : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - - Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); @@ -7958,13 +7942,9 @@ struct llm_build_mpt : public llm_graph_context { cb(cur, "wqkv_clamped", il); } - ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); // Q/K Layernorm if (model.layers[il].attn_q_norm) { @@ -7972,26 +7952,16 @@ struct llm_build_mpt : public llm_graph_context { model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, LLM_NORM, il); - cb(Qcur, "Qcur", il); Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, LLM_NORM, il); - cb(Kcur, "Kcur", il); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - } else { - Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cb(Qcur, "Qcur", il); - - Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - cb(Kcur, "Kcur", il); } - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); @@ -8240,11 +8210,9 @@ struct llm_build_qwen : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 2*sizeof(float)*(n_embd)); - - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 2*sizeof(float)*(n_embd)); // using mode = 2 for neox mode Qcur = ggml_rope_ext( @@ -9219,21 +9187,17 @@ struct llm_build_phi2 : public llm_graph_context { Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); } else { Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -9357,21 +9321,17 @@ struct llm_build_phi3 : public llm_graph_context { Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 0 * sizeof(float) * (n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd)); - Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); } else { Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -9621,18 +9581,14 @@ struct llm_build_gpt2 : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); @@ -9727,9 +9683,7 @@ struct llm_build_codeshell : public llm_graph_context { ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -12601,9 +12555,7 @@ struct llm_build_gptneox : public llm_graph_context { ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -13736,18 +13688,14 @@ struct llm_build_jais : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*cur->nb[0]*(n_embd)); - ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd)); - ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa)); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*cur->nb[0]*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*cur->nb[0]*(n_embd)); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/float(n_embd_head), il); @@ -13859,8 +13807,7 @@ struct llm_build_chatglm : public llm_graph_context { } Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); } //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor); @@ -13993,8 +13940,7 @@ struct llm_build_glm4 : public llm_graph_context { } Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); } Qcur = ggml_rope_ext( @@ -17293,16 +17239,14 @@ private: const int64_t k_offset = n_embd_head_q * n_head; const int64_t v_offset = k_offset + n_embd_head_k * n_head_kv; - ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_q, n_head, n_tokens, n_embd_head_q * sizeof(float), qkv->nb[1], q_offset * ggml_element_size(qkv)); + ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_q, n_head, n_tokens, n_embd_head_q * sizeof(float), qkv->nb[1], q_offset * ggml_element_size(qkv)); ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k * sizeof(float), qkv->nb[1], k_offset * ggml_element_size(qkv)); - ggml_tensor * Vcur = ggml_view_2d(ctx0, qkv, n_embd_head_v * n_head_kv, n_tokens, qkv->nb[1], v_offset * ggml_element_size(qkv)); + ggml_tensor * Vcur = ggml_view_3d(ctx0, qkv, n_embd_head_v, n_head_kv, n_tokens, n_embd_head_v * sizeof(float), qkv->nb[1], v_offset * ggml_element_size(qkv)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head_v, n_head_kv, n_tokens); - Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); From 663027fd5490438ce9c27ea866a560e1e268d11f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 8 Sep 2025 10:26:36 +0300 Subject: [PATCH 12/46] context : fix n_outputs during reserve (#15858) ggml-ci --- src/llama-context.cpp | 4 ++-- src/llama-graph.cpp | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 6b3188be4b..874c6f82cb 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -285,8 +285,8 @@ llama_context::llama_context( const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - // avoid reserving graphs with zero outputs - n_outputs = 1; + // avoid reserving graphs with zero outputs - assume one output per sequence + n_outputs = n_seqs; LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 4abb6008dd..7f254b25cd 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1431,7 +1431,8 @@ ggml_tensor * llm_graph_context::build_attn( // [TAG_NO_CACHE_PAD] // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams - assert(!ubatch.equal_seqs() || (k_cur->ne[3] == 1 && k_cur->ne[3] == ubatch.n_seqs_unq)); + // but it might not be worth it: https://github.com/ggml-org/llama.cpp/pull/15636 + //assert(!ubatch.equal_seqs() || (k_cur->ne[3] == 1 && k_cur->ne[3] == ubatch.n_seqs_unq)); ggml_tensor * q = q_cur; ggml_tensor * k = k_cur; From a885dcff11a7b73f9377812d6151f6b15d307de0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 8 Sep 2025 10:27:07 +0300 Subject: [PATCH 13/46] batched-bench : fix llama_synchronize usage during prompt processing (#15835) ggml-ci --- tools/batched-bench/batched-bench.cpp | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/tools/batched-bench/batched-bench.cpp b/tools/batched-bench/batched-bench.cpp index 46dd12caae..fcfcd80771 100644 --- a/tools/batched-bench/batched-bench.cpp +++ b/tools/batched-bench/batched-bench.cpp @@ -71,7 +71,7 @@ int main(int argc, char ** argv) { llama_batch batch = llama_batch_init(n_kv_max, 0, 1); // decode in batches of ctx_params.n_batch tokens - auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) { + auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch, bool synchronize) { for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); @@ -91,7 +91,9 @@ int main(int argc, char ** argv) { return false; } - llama_synchronize(ctx); + if (synchronize) { + llama_synchronize(ctx); + } } return true; @@ -103,7 +105,7 @@ int main(int argc, char ** argv) { common_batch_add(batch, get_token_rand(), i, { 0 }, false); } - if (!decode_helper(ctx, batch, ctx_params.n_batch)) { + if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) { LOG_ERR("%s: llama_decode() failed\n", __func__); return 1; } @@ -138,15 +140,17 @@ int main(int argc, char ** argv) { } } - const auto t_pp_start = ggml_time_us(); - llama_memory_clear(mem, false); - if (!decode_helper(ctx, batch, ctx_params.n_batch)) { + const auto t_pp_start = ggml_time_us(); + + if (!decode_helper(ctx, batch, ctx_params.n_batch, false)) { LOG_ERR("%s: llama_decode() failed\n", __func__); return 1; } + llama_synchronize(ctx); + const auto t_pp_end = ggml_time_us(); if (is_pp_shared) { @@ -158,7 +162,7 @@ int main(int argc, char ** argv) { // run one dummy token to apply the memory copy common_batch_clear(batch); common_batch_add(batch, get_token_rand(), pp + 0, { 0 }, true); - if (!decode_helper(ctx, batch, ctx_params.n_batch)) { + if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) { LOG_ERR("%s: llama_decode() failed\n", __func__); return 1; } @@ -175,7 +179,7 @@ int main(int argc, char ** argv) { common_batch_add(batch, get_token_rand(), pp + i, { j }, true); } - if (!decode_helper(ctx, batch, ctx_params.n_batch)) { + if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) { LOG_ERR("%s: llama_decode() failed\n", __func__); return 1; } From 233d773d02c37982badddf3994e9953a175f34da Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Mon, 8 Sep 2025 09:44:34 +0200 Subject: [PATCH 14/46] convert : force setting sliding_window from original config (#15867) * convert : force setting sliding_window from original config This commit modifies the set_gguf_parameters method for EmbeddingGemma so that it reads the sliding_window parameter from the original model config.json and uses that value. The motivation for this change is that the Gemma3TextConfig constructor adjusts the sliding_window value, which can lead to inconsistencies when converting models as we expects this value to match the original model's configuration. Refs: https://github.com/huggingface/transformers/blob/bb45d3631ec7026db04a77d33a52b31766372160/src/transformers/models/gemma3/configuration_gemma3.py#L230 * fix flake8 error * add link to huggingface PR --- convert_hf_to_gguf.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 717b8a6595..62a546ee22 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -5128,6 +5128,20 @@ class EmbeddingGemma(Gemma3Model): def set_gguf_parameters(self): super().set_gguf_parameters() + + # Override the sliding window size as it gets adjusted by the Gemma3TextConfig + # constructor. We want to use the value from the original model's config.json. + # ref: https://github.com/huggingface/transformers/pull/40700 + with open(self.dir_model / "config.json", "r", encoding="utf-8") as f: + config = json.load(f) + orig_sliding_window = config.get("sliding_window") + if orig_sliding_window is None: + raise ValueError("sliding_window not found in model config - this is required for the model") + + logger.info(f"Using original sliding_window from config: {orig_sliding_window} " + f"instead of {self.hparams['sliding_window']}") + self.gguf_writer.add_sliding_window(orig_sliding_window) + self._try_set_pooling_type() From 5ef22d281de9c5eaaf616874bc490b89241128cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Mon, 8 Sep 2025 11:55:44 +0200 Subject: [PATCH 15/46] CUDA: non-contiguous src0 not supported for PAD (#15869) --- ggml/src/ggml-cuda/ggml-cuda.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 0c01eb6fa8..9f5dcd9341 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3574,9 +3574,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: case GGML_OP_GROUP_NORM: + case GGML_OP_PAD: return ggml_is_contiguous(op->src[0]); case GGML_OP_UPSCALE: - case GGML_OP_PAD: case GGML_OP_PAD_REFLECT_1D: case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: From 9fcb29f22f5c33c04c7f0daebb24057899d67a1a Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Mon, 8 Sep 2025 17:33:01 +0700 Subject: [PATCH 16/46] ggml: allow casting between f32 and i32 (#15783) * ggml: allow casting between f32 and i32 * fix cuda * add vulkan * fix CPU non-cont * add non-cont test case * add note * extend test number range * correct note * add cont version for vulkan --- ggml/include/ggml-cpu.h | 1 + ggml/include/ggml.h | 1 + ggml/src/ggml-cpu/ggml-cpu.c | 15 +- ggml/src/ggml-cpu/ops.cpp | 160 ++++++++++++++++++ ggml/src/ggml-cuda/convert.cuh | 2 + ggml/src/ggml-cuda/cpy.cu | 4 + ggml/src/ggml-cuda/ggml-cuda.cu | 6 + ggml/src/ggml-metal/ggml-metal.m | 15 ++ ggml/src/ggml-metal/ggml-metal.metal | 2 + ggml/src/ggml-vulkan/ggml-vulkan.cpp | 29 +++- .../vulkan-shaders/vulkan-shaders-gen.cpp | 4 + tests/test-backend-ops.cpp | 11 ++ 12 files changed, 247 insertions(+), 3 deletions(-) diff --git a/ggml/include/ggml-cpu.h b/ggml/include/ggml-cpu.h index 1a78935aa0..9edd485136 100644 --- a/ggml/include/ggml-cpu.h +++ b/ggml/include/ggml-cpu.h @@ -134,6 +134,7 @@ extern "C" { GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void); GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t); + GGML_BACKEND_API void ggml_cpu_fp32_to_i32 (const float *, int32_t *, int64_t); GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t); GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t); GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t); diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index c01b98ac78..058f4267f7 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1404,6 +1404,7 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); + // note: casting from f32 to i32 will discard the fractional part GGML_API struct ggml_tensor * ggml_cast( struct ggml_context * ctx, struct ggml_tensor * a, diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 09772e8061..c131290849 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -373,6 +373,9 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, }, + [GGML_TYPE_I32] = { + .from_float = (ggml_from_float_t) ggml_cpu_fp32_to_i32, + }, }; const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) { @@ -2696,7 +2699,10 @@ struct ggml_cplan ggml_graph_plan( if (ggml_is_quantized(node->type) || // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32 (node->src[0]->type == GGML_TYPE_F16 && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) || - (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) { + (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16) || + // conversion between F32 and I32 + (node->src[0]->type == GGML_TYPE_F32 && node->src[1] && node->src[1]->type == GGML_TYPE_I32) || + (node->src[0]->type == GGML_TYPE_I32 && node->src[1] && node->src[1]->type == GGML_TYPE_F32)) { cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks; } } break; @@ -3258,6 +3264,13 @@ void ggml_cpu_fp32_to_bf16(const float * x, ggml_bf16_t * y, int64_t n) { } } +void ggml_cpu_fp32_to_i32(const float * x, int32_t * y, int64_t n) { + int64_t i = 0; + for (; i < n; ++i) { + y[i] = x[i]; + } +} + void ggml_cpu_bf16_to_fp32(const ggml_bf16_t * x, float * y, int64_t n) { int64_t i = 0; #if defined(__AVX2__) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 0bb767e01a..9adf910769 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -776,6 +776,24 @@ static void ggml_compute_forward_dup_f32( id += ne00 * (ne01 - ir1); } } + } else if (dst->type == GGML_TYPE_I32) { + size_t id = 0; + int32_t * dst_ptr = (int32_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = *src0_ptr; + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } } else { GGML_ABORT("fatal error"); // TODO: implement } @@ -947,6 +965,144 @@ static void ggml_compute_forward_dup_f32( } } } + } else if (dst->type == GGML_TYPE_I32) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + *(int32_t *) dst_ptr = *(const float *) src0_ptr; + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else { + GGML_ABORT("fatal error"); // TODO: implement + } +} + +static void ggml_compute_forward_dup_i32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); + + GGML_TENSOR_UNARY_OP_LOCALS + + const int ith = params->ith; // thread index + const int nth = params->nth; // number of threads + + // parallelize by rows + const int nr = ne01; + // number of rows per thread + const int dr = (nr + nth - 1) / nth; + // row range for this thread + const int ir0 = dr * ith; + const int ir1 = MIN(ir0 + dr, nr); + + // dst counters + + int64_t i10 = 0; + int64_t i11 = 0; + int64_t i12 = 0; + int64_t i13 = 0; + + // TODO: not optimal, but works + if (dst->type == GGML_TYPE_F32) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + *(float *) dst_ptr = *(const int32_t *) src0_ptr; + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } } else { GGML_ABORT("fatal error"); // TODO: implement } @@ -1177,6 +1333,10 @@ void ggml_compute_forward_dup( { ggml_compute_forward_dup_f32(params, dst); } break; + case GGML_TYPE_I32: + { + ggml_compute_forward_dup_i32(params, dst); + } break; default: { if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) { diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index c62e8a1b10..ef9e129950 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -38,6 +38,8 @@ template return __float2bfloat16(float(x)); } else if constexpr(std::is_same_v) { return __bfloat162float(x); + } else if constexpr(std::is_same_v) { + return int32_t(x); } else { return float(x); } diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index c40db08ced..8567c3d5a1 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -374,6 +374,10 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else { GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 9f5dcd9341..a88b9f75ef 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3461,6 +3461,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) { return true; } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) { + return true; + } + if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) { + return true; + } if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) { return true; } diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index c1a0a2bef1..578bdd6eca 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -583,6 +583,8 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_CPY_F16_F32, GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, + GGML_METAL_KERNEL_TYPE_CPY_F32_I32, + GGML_METAL_KERNEL_TYPE_CPY_I32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, @@ -1616,6 +1618,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_I32, cpy_f32_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_I32_F32, cpy_i32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); @@ -1945,6 +1949,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_I32: return true; default: return false; @@ -1977,6 +1982,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex default: return false; } + case GGML_TYPE_I32: + return op->type == GGML_TYPE_F32; default: return false; }; @@ -5680,6 +5687,7 @@ static int ggml_metal_encode_node( switch (dstt) { case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; + case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_I32].pipeline; break; case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break; case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break; @@ -5691,6 +5699,13 @@ static int ggml_metal_encode_node( default: GGML_ABORT("not implemented"); }; } break; + case GGML_TYPE_I32: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_I32_F32].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; case GGML_TYPE_F16: { switch (dstt) { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 2d56c62674..4dc762bf1a 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -5338,6 +5338,8 @@ typedef decltype(kernel_cpy) kernel_cpy_t; template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy; template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy; #if defined(GGML_METAL_USE_BF16) template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy; #endif diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 36951bceb8..e6245d0cfc 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -506,8 +506,8 @@ struct vk_device_struct { vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_roll_f32; vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32; - vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16; - vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16; + vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32; + vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32; vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_set_rows[GGML_TYPE_COUNT]; @@ -3226,12 +3226,16 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_i32_f32, "cpy_i32_f32", cpy_i32_f32_len, cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_i32, "cpy_f32_i32", cpy_f32_i32_len, cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); if (device->float_controls_rte_fp16) { ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); @@ -5693,6 +5697,20 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_cpy_f32_bf16; } } + if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_I32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f32_i32; + } else { + return ctx->device->pipeline_cpy_f32_i32; + } + } + if (src->type == GGML_TYPE_I32 && to == GGML_TYPE_F32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_i32_f32; + } else { + return ctx->device->pipeline_cpy_i32_f32; + } + } if (src->type == GGML_TYPE_F32) { switch (to) { case GGML_TYPE_Q4_0: @@ -12224,6 +12242,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return true; } + if ( + src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32 || + src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32 + ) { + return true; + } + // We can handle copying from a type to the same type if it's // contiguous (memcpy). We use f16 or f32 shaders to do the copy, // so the type/block size must be a multiple of 4. diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 27394c170f..b6570e0203 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -560,10 +560,14 @@ void process_shaders() { string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("contig_cpy_f32_i32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}}); + string_to_spv("contig_cpy_i32_f32", "contig_copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}}); string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); + string_to_spv("cpy_f32_i32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}}); + string_to_spv("cpy_i32_f32", "copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}}); for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index f4cd8f65cf..e8e6e1000b 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2457,6 +2457,13 @@ struct test_cpy : public test_case { return out; } + + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + // test extended range of values to check if casting between f32 and i32 is consistent + init_tensor_uniform(t, -150.f, 150.f); + } + } }; // GGML_OP_CONT @@ -6007,6 +6014,10 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {1, 0, 2, 3})); // cpy not-contiguous } } + test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_I32, {256, 2, 3, 4})); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_I32, {256, 2, 3, 4}, {1, 0, 2, 3})); + test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4})); + test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}, {1, 0, 2, 3})); test_cases.emplace_back(new test_cont()); test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 1 ,1})); From f28d4f4ac963f182ea9d0fe9e269f3f5f3782aaf Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 8 Sep 2025 13:34:56 +0300 Subject: [PATCH 17/46] metal : refactor + optimize (#15857) * metal : refactor ggml-ci * cont : refactor FA-vec kernel * cont : print metal library load time * minor : warn to debug + bettern kernel names ggml-ci * metal : optimize mul_mv q8_0 ggml-ci * metal : simplify FA pipeline creation functions ggml-ci * metal : improve naming consistency * metal : safer function constants offsets ggml-ci * metal : comments ggml-ci --- ggml/src/ggml-metal/ggml-metal-impl.h | 48 +- ggml/src/ggml-metal/ggml-metal.m | 1077 +++++++---------- ggml/src/ggml-metal/ggml-metal.metal | 1611 +++++++++++++++---------- tests/test-backend-ops.cpp | 4 +- 4 files changed, 1415 insertions(+), 1325 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index b9d3639448..651943fa92 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -20,8 +20,8 @@ #define N_R0_Q5_1 4 #define N_SG_Q5_1 2 -#define N_R0_Q8_0 4 -#define N_SG_Q8_0 2 +#define N_R0_Q8_0 2 +#define N_SG_Q8_0 4 #define N_R0_MXFP4 2 #define N_SG_MXFP4 2 @@ -68,6 +68,11 @@ #define N_R0_IQ4_XS 2 #define N_SG_IQ4_XS 2 +// function constants offsets +#define FC_FLASH_ATTN_EXT 100 +#define FC_FLASH_ATTN_EXT_VEC 200 +#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300 + // kernel argument structs // // - element counters (e.g. ne00) typically use int32_t to reduce register usage @@ -236,9 +241,11 @@ typedef struct { int32_t ne11; int32_t ne_12_2; // assume K and V are same shape int32_t ne_12_3; + int32_t ns10; uint64_t nb11; uint64_t nb12; uint64_t nb13; + int32_t ns20; uint64_t nb21; uint64_t nb22; uint64_t nb23; @@ -258,10 +265,43 @@ typedef struct { float logit_softcap; } ggml_metal_kargs_flash_attn_ext; +typedef struct { + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne11; + int32_t ne_12_2; // assume K and V are same shape + int32_t ne_12_3; + int32_t ns10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ns20; + uint64_t nb21; + uint64_t nb22; + uint64_t nb23; + int32_t ne32; + int32_t ne33; + uint64_t nb31; + uint64_t nb32; + uint64_t nb33; + int32_t ne1; + int32_t ne2; + int32_t ne3; + float scale; + float max_bias; + float m0; + float m1; + int32_t n_head_log2; + float logit_softcap; +} ggml_metal_kargs_flash_attn_ext_vec; + typedef struct { int32_t nrows; - int32_t ne20; -} ggml_metal_kargs_flash_attn_ext_reduce; +} ggml_metal_kargs_flash_attn_ext_vec_reduce; typedef struct { int32_t ne00; diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 578bdd6eca..eeb6c9d4b7 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -174,6 +174,19 @@ struct ggml_metal_kernel { id pipeline; }; +@interface ggml_metal_kernel_wrapper : NSObject + +@property (nonatomic, assign) struct ggml_metal_kernel kernel; + +@end + +@implementation ggml_metal_kernel_wrapper +- (void) dealloc { + [_kernel.pipeline release]; + [super dealloc]; +} +@end + enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_ADD, GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, @@ -454,126 +467,6 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE, GGML_METAL_KERNEL_TYPE_SET_I32, GGML_METAL_KERNEL_TYPE_SET_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, @@ -884,8 +777,12 @@ struct ggml_backend_metal_context { dispatch_queue_t d_queue; + // the set of pre-compiled kernels for this context struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT]; + // additional, inference-time compiled kernels + NSMutableDictionary * kernels_ext; + // capture state bool capture_next_compute; bool capture_started; @@ -951,6 +848,8 @@ static void * ggml_metal_host_malloc(size_t n) { // - if not found, load the source and compile it // - if that fails, return NULL static id ggml_metal_load_library(id device, bool use_bfloat) { + const int64_t t_start = ggml_time_us(); + id metal_library = nil; NSError * error = nil; NSString * src = nil; @@ -1074,6 +973,8 @@ static id ggml_metal_load_library(id device, bool use_bfl [src release]; #endif // GGML_METAL_EMBED_LIBRARY + GGML_LOG_INFO("%s: loaded in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6); + return metal_library; } @@ -1271,7 +1172,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4, get_rows_mxfp4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4, get_rows_mxfp4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); @@ -1489,126 +1390,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40, flash_attn_ext_f16_h40, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, flash_attn_ext_f16_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, flash_attn_ext_f16_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40, flash_attn_ext_bf16_h40, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, flash_attn_ext_bf16_h192, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40, flash_attn_ext_q4_0_h40, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, flash_attn_ext_q4_0_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40, flash_attn_ext_q4_1_h40, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, flash_attn_ext_q4_1_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40, flash_attn_ext_q5_0_h40, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, flash_attn_ext_q5_0_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40, flash_attn_ext_q5_1_h40, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, flash_attn_ext_q5_1_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40, flash_attn_ext_q8_0_h40, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, flash_attn_ext_q8_0_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, flash_attn_ext_vec_q4_1_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, flash_attn_ext_vec_q5_0_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, flash_attn_ext_vec_q5_1_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, flash_attn_ext_vec_q8_0_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96, flash_attn_ext_vec_q4_1_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96, flash_attn_ext_vec_q5_0_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96, flash_attn_ext_vec_q5_1_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, flash_attn_ext_vec_q8_0_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192, flash_attn_ext_vec_f16_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192, flash_attn_ext_vec_bf16_h192, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192, flash_attn_ext_vec_q4_0_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192, flash_attn_ext_vec_q4_1_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192, flash_attn_ext_vec_q5_0_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192, flash_attn_ext_vec_q5_1_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192, flash_attn_ext_vec_q8_0_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128, flash_attn_ext_vec_f16_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128, flash_attn_ext_vec_bf16_hk192_hv128, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128, flash_attn_ext_vec_q4_0_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128, flash_attn_ext_vec_q4_1_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128, flash_attn_ext_vec_q5_0_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128, flash_attn_ext_vec_q5_1_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128, flash_attn_ext_vec_q8_0_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, flash_attn_ext_vec_f16_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512, flash_attn_ext_vec_bf16_hk576_hv512, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512, flash_attn_ext_vec_q4_0_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512, flash_attn_ext_vec_q4_1_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE, flash_attn_ext_reduce, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); @@ -1655,9 +1436,219 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true); } + ctx->kernels_ext = [[NSMutableDictionary alloc] init]; + return ctx; } +static id ggml_metal_get_kernel(struct ggml_backend_metal_context * ctx, const char * name) { + NSString * key = [NSString stringWithUTF8String:name]; + + ggml_metal_kernel_wrapper * obj = [ctx->kernels_ext objectForKey:key]; + if (obj) { + return obj.kernel.pipeline; + } + + return nil; +} + +static id ggml_metal_compile_kernel(ggml_backend_t backend, const char * base, const char * name, MTLFunctionConstantValues * cv) { + struct ggml_backend_metal_context * ctx = backend->context; + struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; + + id res = nil; + + @autoreleasepool { + NSError * error = nil; + + NSString * base_func = [NSString stringWithUTF8String:base]; + + GGML_LOG_DEBUG("%s: compiling kernel: base = '%s', name = '%s'\n", __func__, base, name); + + // TODO: make sure it is thread-safe to compile kernels in parallel + id metal_function = [ctx_dev->mtl_library newFunctionWithName:base_func constantValues:cv error:&error]; + if (!metal_function) { + GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + + return nil; + } + + struct ggml_metal_kernel kernel = { + /*.pipeline =*/ [ctx_dev->mtl_device newComputePipelineStateWithFunction:metal_function error:&error], + }; + + ggml_metal_kernel_wrapper * obj = [[ggml_metal_kernel_wrapper alloc] init]; + obj.kernel = kernel; + + res = obj.kernel.pipeline; + + NSString * key = [NSString stringWithUTF8String:name]; + [ctx->kernels_ext setObject:obj forKey:key]; + + GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, (void *) kernel.pipeline, + (int) kernel.pipeline.maxTotalThreadsPerThreadgroup, + (int) kernel.pipeline.threadExecutionWidth); + } + + return res; +} + +static id ggml_metal_get_pipeline_flash_attn_ext( + ggml_backend_t backend, struct ggml_tensor * op, + bool has_mask, + bool has_sinks, + bool has_bias, + bool has_scap, + int32_t nsg) { + struct ggml_backend_metal_context * ctx = backend->context; + + char base[256]; + char name[256]; + + @autoreleasepool { + MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init]; + + const int32_t dk = (int32_t) op->src[1]->ne[0]; + const int32_t dv = (int32_t) op->src[2]->ne[0]; + + const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0]; + const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0]; + + snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d", + "flash_attn_ext", + ggml_type_name(op->src[1]->type), + dk, + dv); + + snprintf(name, 256, "kernel_%s_%s_dk%d_dv%d_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d", + "flash_attn_ext", + ggml_type_name(op->src[1]->type), + dk, + dv, + has_mask, + has_sinks, + has_bias, + has_scap, + ns10, + ns20, + nsg); + + id res = ggml_metal_get_kernel(ctx, name); + if (res) { + // kernel found + return res; + } + + cv = [[MTLFunctionConstantValues alloc] init]; + + [cv setConstantValue:&has_mask type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 0]; + [cv setConstantValue:&has_sinks type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 1]; + [cv setConstantValue:&has_bias type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 2]; + [cv setConstantValue:&has_scap type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 3]; + + [cv setConstantValue:&ns10 type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT + 20]; + [cv setConstantValue:&ns20 type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT + 21]; + [cv setConstantValue:&nsg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT + 22]; + + return ggml_metal_compile_kernel(backend, base, name, cv); + } +} + +static id ggml_metal_get_pipeline_flash_attn_ext_vec( + ggml_backend_t backend, struct ggml_tensor * op, + bool has_mask, + bool has_sinks, + bool has_bias, + bool has_scap, + int32_t nsg, + int32_t nwg) { + struct ggml_backend_metal_context * ctx = backend->context; + + char base[256]; + char name[256]; + + @autoreleasepool { + MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init]; + + const int32_t dk = (int32_t) op->src[1]->ne[0]; + const int32_t dv = (int32_t) op->src[2]->ne[0]; + + const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0]; + const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0]; + + snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d", + "flash_attn_ext_vec", + ggml_type_name(op->src[1]->type), + dk, + dv); + + snprintf(name, 256, "kernel_%s_%s_dk%d_dv%d_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d", + "flash_attn_ext_vec", + ggml_type_name(op->src[1]->type), + dk, + dv, + has_mask, + has_sinks, + has_bias, + has_scap, + ns10, + ns20, + nsg, nwg); + + id res = ggml_metal_get_kernel(ctx, name); + if (res) { + // kernel found + return res; + } + + cv = [[MTLFunctionConstantValues alloc] init]; + + [cv setConstantValue:&has_mask type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 0]; + [cv setConstantValue:&has_sinks type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 1]; + [cv setConstantValue:&has_bias type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 2]; + [cv setConstantValue:&has_scap type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 3]; + + [cv setConstantValue:&ns10 type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 20]; + [cv setConstantValue:&ns20 type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 21]; + [cv setConstantValue:&nsg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 22]; + [cv setConstantValue:&nwg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 23]; + + return ggml_metal_compile_kernel(backend, base, name, cv); + } +} + +static id ggml_metal_get_pipeline_flash_attn_ext_vec_reduce( + ggml_backend_t backend, struct ggml_tensor * op, + int32_t dv, + int32_t nwg) { + struct ggml_backend_metal_context * ctx = backend->context; + + char base[256]; + char name[256]; + + @autoreleasepool { + MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init]; + + snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce"); + snprintf(name, 256, "kernel_flash_attn_ext_vec_reduce_dv=%d_nwg=%d", dv, nwg); + + id res = ggml_metal_get_kernel(ctx, name); + if (res) { + // kernel found + return res; + } + + cv = [[MTLFunctionConstantValues alloc] init]; + + [cv setConstantValue:&dv type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC_REDUCE + 0]; + [cv setConstantValue:&nwg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC_REDUCE + 1]; + + return ggml_metal_compile_kernel(backend, base, name, cv); + } + + GGML_UNUSED(op); +} + static void ggml_metal_free(struct ggml_backend_metal_context * ctx) { GGML_LOG_INFO("%s: deallocating\n", __func__); @@ -1665,6 +1656,11 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) { [ctx->kernels[i].pipeline release]; } + if (ctx->kernels_ext) { + [ctx->kernels_ext release]; + ctx->kernels_ext = nil; + } + Block_release(ctx->encode_async); [ctx->queue release]; @@ -3772,6 +3768,7 @@ static int ggml_metal_encode_node( { nsg = N_SG_Q8_0; nr0 = N_R0_Q8_0; + smem = 32*sizeof(float)*N_R0_Q8_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline; } break; case GGML_TYPE_MXFP4: @@ -3908,7 +3905,12 @@ static int ggml_metal_encode_node( if (smem > 0) { [encoder setThreadgroupMemoryLength:smem atIndex:0]; } - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + + if (src0t == GGML_TYPE_Q8_0) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0 - 1)/(nr0), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } else { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } } } break; case GGML_OP_MUL_MAT_ID: @@ -4129,6 +4131,7 @@ static int ggml_metal_encode_node( { nsg = N_SG_Q8_0; nr0 = N_R0_Q8_0; + smem = 32*sizeof(float)*N_R0_Q8_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline; } break; case GGML_TYPE_MXFP4: @@ -4274,7 +4277,12 @@ static int ggml_metal_encode_node( if (smem > 0) { [encoder setThreadgroupMemoryLength:smem atIndex:0]; } - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + + if (src0t == GGML_TYPE_Q8_0) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0 - 1)/(nr0), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } else { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } } } break; case GGML_OP_GET_ROWS: @@ -5125,6 +5133,7 @@ static int ggml_metal_encode_node( float scale; float max_bias; float logit_softcap; + memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale)); memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias)); memcpy(&logit_softcap, ((const int32_t *) dst->op_params) + 2, sizeof(logit_softcap)); @@ -5133,398 +5142,24 @@ static int ggml_metal_encode_node( scale /= logit_softcap; } + const bool has_mask = src3 != NULL; + const bool has_sinks = src4 != NULL; + const bool has_bias = max_bias != 0.0f; + const bool has_scap = logit_softcap != 0.0f; + const uint32_t n_head = src0->ne[2]; const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - id pipeline = nil; - - bool use_vec_kernel = false; + GGML_ASSERT(ne01 < 65536); // use non-vec kernel if the batch size is large or if the vec-kernel is not supported for this head size - if (ne01 >= 20 || (ne00 == 40 || ne00 == 80 || ne00 == 112)) { - switch (src1->type) { - case GGML_TYPE_F16: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40 ].pipeline; break; - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_BF16: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40 ].pipeline; break; - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_Q4_0: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40 ].pipeline; break; - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_Q4_1: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40 ].pipeline; break; - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_Q5_0: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40 ].pipeline; break; - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_Q5_1: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40 ].pipeline; break; - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_Q8_0: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40 ].pipeline; break; - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } else { - use_vec_kernel = true; - - switch (ne00) { - case 64: - { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } break; - case 96: - { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } break; - case 128: - { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } break; - case 192: - { - if (ne20 == 128) { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } else { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } - } break; - case 256: - { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } break; - case 576: - { - if (ne20 == 512) { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } else { - GGML_LOG_ERROR("unsupported size: %lld\n", ne20); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - - ggml_metal_kargs_flash_attn_ext args = { - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne11 =*/ ne11, - /*.ne_12_2 =*/ ne12, - /*.ne_12_3 =*/ ne13, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.nb21 =*/ nb21, - /*.nb22 =*/ nb22, - /*.nb23 =*/ nb23, - /*.ne32 =*/ ne32, - /*.ne33 =*/ ne33, - /*.nb31 =*/ nb31, - /*.nb32 =*/ nb32, - /*.nb33 =*/ nb33, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, - /*.scale =*/ scale, - /*.max_bias =*/ max_bias, - /*.m0 =*/ m0, - /*.m1 =*/ m1, - /*.n_head_log2 =*/ n_head_log2, - /*.logit_softcap =*/ logit_softcap, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - if (id_src3) { - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4]; - } - if (id_src4) { - [encoder setBuffer:id_src4 offset:offs_src4 atIndex:5]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5]; - } - - if (!use_vec_kernel) { + if (ne01 >= 20 || (ne00 % 32 != 0)) { // half8x8 kernel const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !! GGML_ASSERT(nqptg <= 32); GGML_ASSERT(nqptg % 8 == 0); @@ -5532,34 +5167,90 @@ static int ggml_metal_encode_node( const int is_q = ggml_is_quantized(src1->type) ? 1 : 0; - // 2*(2*ncpsg + nqptg)*(nsg) - // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float) + // 2*(2*ncpsg) + // ncpsg soft_max values + ncpsg mask values // // 16*32*(nsg) // the shared memory needed for the simdgroups to load the KV cache // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG // -#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16)) +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*GGML_PAD(ne20, 64) + 2*(2*ncpsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16)) - int64_t nsgmax = 2; - - while (true) { - const size_t smem = FATTN_SMEM(nsgmax); - if (smem > device.maxThreadgroupMemoryLength/2) { - break; - } - nsgmax *= 2; - } - nsgmax /= 2; + //int64_t nsgmax = 4; + // + //if (is_q) { + // nsgmax = 2; + // while (true) { + // const size_t smem = FATTN_SMEM(nsgmax); + // if (smem > device.maxThreadgroupMemoryLength/2) { + // break; + // } + // nsgmax *= 2; + // } + // nsgmax /= 2; + //} // simdgroups per threadgroup (a.k.a. warps) - const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; + //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; + int32_t nsg = 4; const size_t smem = FATTN_SMEM(nsg); + ggml_metal_kargs_flash_attn_ext args = { + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne_12_2 =*/ ne12, + /*.ne_12_3 =*/ ne13, + /*.ns10 =*/ nb11/nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ns20 =*/ nb21/nb20, + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb23 =*/ nb23, + /*.ne32 =*/ ne32, + /*.ne33 =*/ ne33, + /*.nb31 =*/ nb31, + /*.nb32 =*/ nb32, + /*.nb33 =*/ nb33, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.scale =*/ scale, + /*.max_bias =*/ max_bias, + /*.m0 =*/ m0, + /*.m1 =*/ m1, + /*.n_head_log2 =*/ n_head_log2, + /*.logit_softcap =*/ logit_softcap, + }; + + id pipeline = ggml_metal_get_pipeline_flash_attn_ext(backend, node, has_mask, has_sinks, has_bias, has_scap, nsg); + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; + if (id_src3) { + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4]; + } + if (id_src4) { + [encoder setBuffer:id_src4 offset:offs_src4 atIndex:5]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; - //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg); + //printf("smem: %zu, max: %zu, nsg = %d, ne02 = %d, ne12 = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, ne02, ne12); GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; @@ -5568,7 +5259,7 @@ static int ggml_metal_encode_node( // half4x4 kernel const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - const int64_t nkpsg = 1*ncpsg; // TODO: make adjustable + const int64_t nkpsg = 1*ncpsg; GGML_ASSERT(nqptg <= 32); GGML_ASSERT(nqptg % 1 == 0); @@ -5581,8 +5272,7 @@ static int ggml_metal_encode_node( // ne20*(nsg) // each simdgroup has a full f32 head vector in shared mem to accumulate results // -#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16)) -//#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)))*(sizeof(float)/2), 16)) +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*GGML_PAD(ne20, 128)*(nsg))*(sizeof(float)/2), 16)) int64_t nsgmax = 2; while (true) { @@ -5596,7 +5286,8 @@ static int ggml_metal_encode_node( nsgmax /= 2; // simdgroups per threadgroup (a.k.a. warps) - const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))); + //const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))); + const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) 1024/32))); int64_t nsg = 1; while (nsg <= nsgt) { @@ -5606,28 +5297,86 @@ static int ggml_metal_encode_node( // workgroups // each workgroup handles nsg*nkpsg cache values - uint16_t nwg = 1; - if (4*nsg*nkpsg >= ne11) { - const size_t smem = FATTN_SMEM(nsg); + int32_t nwg = 1; + if (false) { + // for small KV caches, we could launch a single workgroup and write the results directly to dst/ + // however, this does not lead to significant improvement, so disabled + nwg = 1; + nsg = 4; + } else { + nwg = 32; + nsg = 1; + while (2*nwg*nsg*nkpsg < ne11 && nsg < 4) { + nsg *= 2; + } + } - //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax); - GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); + ggml_metal_kargs_flash_attn_ext_vec args = { + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne_12_2 =*/ ne12, + /*.ne_12_3 =*/ ne13, + /*.ns10 =*/ nb11/nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ns20 =*/ nb21/nb20, + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb23 =*/ nb23, + /*.ne32 =*/ ne32, + /*.ne33 =*/ ne33, + /*.nb31 =*/ nb31, + /*.nb32 =*/ nb32, + /*.nb33 =*/ nb33, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.scale =*/ scale, + /*.max_bias =*/ max_bias, + /*.m0 =*/ m0, + /*.m1 =*/ m1, + /*.n_head_log2 =*/ n_head_log2, + /*.logit_softcap =*/ logit_softcap, + }; + id pipeline = ggml_metal_get_pipeline_flash_attn_ext_vec(backend, node, has_mask, has_sinks, has_bias, has_scap, nsg, nwg); + + GGML_ASSERT(nsg*32 <= (int) pipeline.maxTotalThreadsPerThreadgroup); + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; + if (id_src3) { + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4]; + } + if (id_src4) { + [encoder setBuffer:id_src4 offset:offs_src4 atIndex:5]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5]; + } + + const size_t smem = FATTN_SMEM(nsg); + + //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax); + GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); + + if (nwg == 1) { // using 1 workgroup -> write the result directly into dst - [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; - [encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; [encoder setThreadgroupMemoryLength:smem atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; } else { - nwg = 32; - nsg = MIN(4, nsg); - - const size_t smem = FATTN_SMEM(nsg); - - //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax); - GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); - // sanity checks GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3); GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31)); @@ -5647,20 +5396,18 @@ static int ggml_metal_encode_node( //printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20); //printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f); - [encoder setBuffer:h_tmp offset:0 atIndex:6]; - [encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7]; + [encoder setBuffer:h_tmp offset:0 atIndex:6]; [encoder setThreadgroupMemoryLength:smem atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; // reduce the results from the workgroups { - ggml_metal_kargs_flash_attn_ext_reduce args0 = { + ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = { nrows, - ne20, }; - id pipeline0 = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE].pipeline; + id pipeline0 = ggml_metal_get_pipeline_flash_attn_ext_vec_reduce(backend, node, ne20, nwg); [encoder setComputePipelineState:pipeline0]; [encoder setBytes:&args0 length:sizeof(args0) atIndex:0]; @@ -5668,7 +5415,7 @@ static int ggml_metal_encode_node( [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; //printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20); - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(32*32, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(32*nwg, 1, 1)]; } } #undef FATTN_SMEM diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 4dc762bf1a..77be3c5c9d 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -15,6 +15,10 @@ using namespace metal; #define MIN(x, y) ((x) < (y) ? (x) : (y)) #define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; } +#define PAD2(x, n) (((x) + (n) - 1) & ~((n) - 1)) + +#define FOR_UNROLL(x) _Pragma("clang loop unroll(full)") for (x) + #define N_SIMDWIDTH 32 // assuming SIMD group size is 32 // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf @@ -2755,7 +2759,47 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m; } -template +template +static inline void helper_mv_reduce_and_write( + device float * dst_f32, + float sumf[NR0], + const int r0, + const int ne01, + ushort tiisg, + ushort sgitg, + threadgroup char * shmem) { + threadgroup float * shmem_f32[NR0]; + + for (short row = 0; row < NR0; ++row) { + shmem_f32[row] = (threadgroup float *) shmem + NW*row; + + if (sgitg == 0) { + shmem_f32[row][tiisg] = 0.0f; + } + + sumf[row] = simd_sum(sumf[row]); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (short row = 0; row < NR0; ++row) { + if (tiisg == 0) { + shmem_f32[row][sgitg] = sumf[row]; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (short row = 0; row < NR0 && r0 + row < ne01; ++row) { + float tot = simd_sum(shmem_f32[row][tiisg]); + + if (tiisg == 0 && sgitg == 0) { + dst_f32[r0 + row] = tot; + } + } +} + +template void mul_vec_q_n_f32_impl( args_t args, device const char * src0, @@ -2765,45 +2809,51 @@ void mul_vec_q_n_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + constexpr short NQ = 16; + const int nb = args.ne00/QK4_0; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * nsg + sgitg) * nr0; + const int r0 = (tgpig.x*NSG + sgitg)*NR0; + //const int r0 = tgpig.x*NR0; + const int r1 = tgpig.y; + const int im = tgpig.z; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q_type * x = (device const block_q_type *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); // pointers to src0 rows - device const block_q_type * ax[nr0]; - for (int row = 0; row < nr0; ++row) { - const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + device const block_q_type * ax[NR0]; + FOR_UNROLL (int row = 0; row < NR0; ++row) { + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); } + float sumf[NR0] = {0.f}; + + const short ix = (tiisg/(NW/NQ)); + const short il = (tiisg%(NW/NQ))*8; + + //const int ib0 = sgitg*NQ + ix; + const int ib0 = ix; + float yl[16]; // src1 vector cache - float sumf[nr0] = {0.f}; - const short ix = (tiisg/2); - const short il = (tiisg%2)*8; - - device const float * yb = y + ix*QK4_0 + il; + //device const float * yb = y + ix*QK4_0 + il; + device const float * yb = y + ib0*QK4_0 + il; // each thread in a SIMD group deals with half a block. - for (int ib = ix; ib < nb; ib += nw/2) { + //for (int ib = ib0; ib < nb; ib += NSG*NQ) { + for (int ib = ib0; ib < nb; ib += NQ) { float sumy[2] = { 0.f, 0.f }; -#pragma unroll - for (short i = 0; i < 8; i += 2) { + FOR_UNROLL (short i = 0; i < 8; i += 2) { sumy[0] += yb[i + 0] + yb[i + 1]; yl[i + 0] = yb[i + 0]; yl[i + 1] = yb[i + 1]/256.f; @@ -2813,21 +2863,23 @@ void mul_vec_q_n_f32_impl( yl[i + 9] = yb[i + 17]/4096.f; } -#pragma unroll - for (short row = 0; row < nr0; row++) { + FOR_UNROLL (short row = 0; row < NR0; row++) { sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il); } yb += QK4_0 * 16; + //yb += NSG*NQ*QK4_0; } device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; - for (int row = 0; row < nr0; ++row) { + //helper_mv_reduce_and_write(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); + + for (int row = 0; row < NR0; ++row) { const float tot = simd_sum(sumf[row]); - if (tiisg == 0 && first_row + row < args.ne01) { - dst_f32[first_row + row] = tot; + if (tiisg == 0 && r0 + row < args.ne01) { + dst_f32[r0 + row] = tot; } } } @@ -2837,10 +2889,11 @@ kernel void kernel_mul_mv_q4_0_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q4_1_f32( @@ -2848,10 +2901,11 @@ kernel void kernel_mul_mv_q4_1_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_0_f32( @@ -2859,10 +2913,11 @@ kernel void kernel_mul_mv_q5_0_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_1_f32( @@ -2870,15 +2925,14 @@ kernel void kernel_mul_mv_q5_1_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -#define NB_Q8_0 8 - -template +template void kernel_mul_mv_q8_0_f32_impl( args_t args, device const char * src0, @@ -2888,66 +2942,65 @@ void kernel_mul_mv_q8_0_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + constexpr short NQ = 8; + const int nb = args.ne00/QK8_0; - const int r0 = tgpig.x; + const int r0 = tgpig.x*NR0; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; - const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); // pointers to src0 rows - device const block_q8_0 * ax[nr0]; - for (int row = 0; row < nr0; ++row) { - const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + device const block_q8_0 * ax[NR0]; + FOR_UNROLL (short row = 0; row < NR0; ++row) { + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); } - float yl[NB_Q8_0]; - float sumf[nr0] = { 0.f }; + float sumf[NR0] = { 0.f }; - const short ix = tiisg/4; - const short il = tiisg%4; + const short ix = tiisg/(NW/NQ); + const short il = tiisg%(NW/NQ); - device const float * yb = y + ix*QK8_0 + il*NB_Q8_0; + const int ib0 = sgitg*NQ + ix; - // each thread in a SIMD group deals with NB_Q8_0 quants at a time - for (int ib = ix; ib < nb; ib += nw/4) { - for (short i = 0; i < NB_Q8_0; ++i) { + float yl[NQ]; + + device const float * yb = y + ib0*QK8_0 + il*NQ; + + // each thread in a SIMD group deals with NQ quants at a time + for (int ib = ib0; ib < nb; ib += NSG*NQ) { + for (short i = 0; i < NQ; ++i) { yl[i] = yb[i]; } - for (short row = 0; row < nr0; row++) { - device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0; + for (short row = 0; row < NR0; row++) { + device const int8_t * qs = ax[row][ib].qs + il*NQ; + float sumq = 0.f; - for (short iq = 0; iq < NB_Q8_0; ++iq) { - sumq += qs[iq] * yl[iq]; + FOR_UNROLL (short i = 0; i < NQ; ++i) { + sumq += qs[i] * yl[i]; } + sumf[row] += sumq*ax[row][ib].d; } - yb += nw*NB_Q8_0; + yb += NSG*NQ*QK8_0; } device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0; ++row) { - const float tot = simd_sum(sumf[row]); - - if (tiisg == 0 && first_row + row < args.ne01) { - dst_f32[first_row + row] = tot; - } - } + helper_mv_reduce_and_write(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); } [[host_name("kernel_mul_mv_q8_0_f32")]] @@ -2956,10 +3009,11 @@ kernel void kernel_mul_mv_q8_0_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } // mat-vec kernel processing in chunks of float4 @@ -4197,6 +4251,19 @@ kernel void kernel_leaky_relu_f32( dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * args.slope; } +constant bool FC_flash_attn_ext_has_mask [[function_constant(FC_FLASH_ATTN_EXT + 0)]]; +constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]]; +constant bool FC_flash_attn_ext_has_bias [[function_constant(FC_FLASH_ATTN_EXT + 2)]]; +constant bool FC_flash_attn_ext_has_scap [[function_constant(FC_FLASH_ATTN_EXT + 3)]]; + +//constant float FC_flash_attn_ext_scale [[function_constant(FC_FLASH_ATTN_EXT + 10)]]; +//constant float FC_flash_attn_ext_max_bias [[function_constant(FC_FLASH_ATTN_EXT + 11)]]; +//constant float FC_flash_attn_ext_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT + 12)]]; + +constant int32_t FC_flash_attn_ext_ns10 [[function_constant(FC_FLASH_ATTN_EXT + 20)]]; +constant int32_t FC_flash_attn_ext_ns20 [[function_constant(FC_FLASH_ATTN_EXT + 21)]]; +constant int32_t FC_flash_attn_ext_nsg [[function_constant(FC_FLASH_ATTN_EXT + 22)]]; + // ref: https://arxiv.org/pdf/2307.08691.pdf template< typename q_t, // query types in shared memory @@ -4211,6 +4278,7 @@ template< typename qk_t, // Q*K types typename qk8x8_t, typename s_t, // soft-max types + typename s2_t, typename s8x8_t, typename o_t, // attention accumulation types typename o4_t, @@ -4221,12 +4289,12 @@ template< typename vd4x4_t, // value type in device memory short nl_v, void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), - short DK, // K head size - short DV, // V head size - short Q = 8, // queries per threadgroup - short KV = 8, // key/value processed per each simdgroup - short C = 32> // cache items per threadgroup -kernel void kernel_flash_attn_ext( + short DK, // K head size + short DV, // V head size + short Q, // queries per threadgroup + short C, // cache items per threadgroup + short NSG> // number of simd groups +void kernel_flash_attn_ext_impl( constant ggml_metal_kargs_flash_attn_ext & args, device const char * q, device const char * k, @@ -4234,46 +4302,85 @@ kernel void kernel_flash_attn_ext( device const char * mask, device const char * sinks, device char * dst, - threadgroup half * shmem_f16 [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { - const short nsg = ntg.y; // number of simdgroups + threadgroup half * shmem_f16, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const ushort iq3 = tgpig[2]; + const ushort iq2 = tgpig[1]; + const ushort iq1 = tgpig[0]*Q; - const int iq3 = tgpig[2]; - const int iq2 = tgpig[1]; - const int iq1 = tgpig[0]*Q; +#define NS10 (FC_flash_attn_ext_ns10) +#define NS20 (FC_flash_attn_ext_ns20) + + // note: I had some concerns that using this instead of the ugly macros above was affecting performance + // need to re-check carefully and if no regressions are observerd - remove the macros + // the concerns is that maybe using const variables requires extra registers? but not sure if the compiler + // is clever enough to avoid this. unfortunately, using constexpr is not possible with FC + //const short NS10 = FC_flash_attn_ext_ns10; + //const short NS20 = FC_flash_attn_ext_ns20; + + constexpr short KV = 8; constexpr short DK4 = DK/4; constexpr short DK8 = DK/8; constexpr short DK16 = DK/16; constexpr short DV4 = DV/4; - constexpr short DV8 = DV/8; + //constexpr short DV8 = DV/8; constexpr short DV16 = DV/16; + constexpr short PV = PAD2(DV, 64); + constexpr short PV4 = PV/4; + constexpr short PV8 = PV/8; + //constexpr short PV16 = PV/16; + constexpr short NW = N_SIMDWIDTH; - constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float) + constexpr short NQ = Q/NSG; + constexpr short SH = 2*C; // shared memory per simdgroup (s_t == float) - const short TS = nsg*SH; // shared memory size per query in (s_t == float) - const short T = 2*DK + 2*TS; // shared memory size per query in (half) + constexpr short TS = 2*SH; + constexpr short T = DK + 2*PV; // shared memory size per query in (half) - threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t - threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix + threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*T); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*T); // same as above but in q4_t + threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*T + Q*DK); // the result for all queries in 8x8 matrices (the O matrix from the paper) + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*T + Q*DK); + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + Q*T); // scratch buffer for attention, mask and diagonal matrix + threadgroup s2_t * ss2 = (threadgroup s2_t *) (shmem_f16 + Q*T); // same as above but in s2_t - threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory - threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t + threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load K in shared memory + threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as above but in k4x4_t - threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory - threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t + threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load V in shared memory + threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as above but in v4x4_t - // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - o8x8_t lo[DV8]; + // mask storage in shared mem + threadgroup half2 * sm2 = (threadgroup half2 *) (shmem_f16 + Q*T + 2*C); + + // per-query mask pointers + device const half2 * pm2[NQ]; + + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + pm2[jj] = (device const half2 *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); + } + + { + q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + k += ikv2*args.nb12 + ikv3*args.nb13; + v += ikv2*args.nb22 + ikv3*args.nb23; + } // load heads from Q to shared memory - for (short j = sgitg; j < Q; j += nsg) { - device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + device const float4 * q4 = (device const float4 *) ((device const char *) q + j*args.nb01); for (short i = tiisg; i < DK4; i += NW) { if (iq1 + j < args.ne01) { @@ -4284,43 +4391,30 @@ kernel void kernel_flash_attn_ext( } } - // zero out lo - for (short i = 0; i < DV8; ++i) { - lo[i] = make_filled_simdgroup_matrix((o_t) 0.0f); - } + // zero out + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + for (short i = tiisg; i < DV4; i += NW) { + so4[j*PV4 + i] = 0; + } - // zero out shared memory SH - for (short j = 0; j < Q; ++j) { for (short i = tiisg; i < SH; i += NW) { - ss[j*TS + i] = 0.0f; + ss[j*SH + i] = 0.0f; } } threadgroup_barrier(mem_flags::mem_threadgroup); + float S[NQ] = { [0 ... NQ-1] = 0.0f }; + { - float S[Q] = { [0 ... Q-1] = 0.0f }; - float M[Q] = { [0 ... Q-1] = -__FLT_MAX__/2 }; - - // thread indices inside the simdgroup - // TODO: see if we can utilize quad-group functions for better performance - // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (6.9.3) - const short tx = tiisg%4; - const short ty = tiisg/4; - - // broadcast kv - //const short rk2 = args.ne02/args.ne12; - //const short rk3 = args.ne03/args.ne13; - - const short ikv2 = iq2/(args.ne02/args.ne_12_2); - const short ikv3 = iq3/(args.ne03/args.ne_12_3); - - const bool has_mask = mask != q; + float M[NQ] = { [0 ... NQ-1] = -FLT_MAX/2 }; float slope = 1.0f; // ALiBi - if (args.max_bias > 0.0f) { + if (FC_flash_attn_ext_has_bias) { const short h = iq2; const float base = h < args.n_head_log2 ? args.m0 : args.m1; @@ -4331,177 +4425,277 @@ kernel void kernel_flash_attn_ext( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) { - const int ic = ic0 + C*sgitg; - if (ic >= args.ne11) { - break; - } + for (int ic = 0; ic < args.ne11; ic += C) { + // read the mask into shared mem + if (FC_flash_attn_ext_has_mask) { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; - if (has_mask) { - // used to detect blocks full of -INF - float smax = -INFINITY; - - // load the mask in shared memory - #pragma unroll(Q) - for (short j = 0; j < Q; ++j) { - device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); - - const float m = pm[ic + tiisg]; - - ss[j*TS + C + tiisg] = m; - smax = max(smax, m); + sm2[j*SH + tiisg] = pm2[jj][tiisg]; + pm2[jj] += NW; } - smax = simd_max(smax); + threadgroup_barrier(mem_flags::mem_threadgroup); + + // used to detect blocks full of -INF + // skip only when the entire threadgroup is masked + half2 smax2(-MAXHALF/2, -MAXHALF/2); + + FOR_UNROLL (short j = 0; j < Q; ++j) { + smax2 = max(smax2, sm2[j*SH + tiisg]); + } + + smax2 = simd_max(smax2); + + if (max(smax2[0], smax2[1]) <= -MAXHALF/2) { + // this barrier is important + threadgroup_barrier(mem_flags::mem_threadgroup); - if (smax == -INFINITY) { continue; } } // Q*K^T - { - for (short cc = 0; cc < C/8; ++cc) { + // this is compile-time check, so it does not have runtime overhead + if (is_same::value) { + // we can read directly from global memory + device const k_t * pk = (device const k_t *) ((device const char *) k + ic*args.nb11); + threadgroup const q_t * pq = sq; + threadgroup s_t * ps = ss; + + pk += sgitg*(8*NS10); + ps += sgitg*(8*1); + + static_assert((C/8) % NSG == 0, ""); + + constexpr short NC = (C/8)/NSG; + + // TODO: not good to unroll for large contexts - not sure why? + for (short cc = 0; cc < NC; ++cc) { qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); - // this is compile-time check, so it does not have runtime overhead - if (is_same::value) { - // we can read directly from global memory - device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); + if (DK8 % 16 != 0) { + k8x8_t mk; + q8x8_t mq; - #pragma unroll(DK8) - for (short i = 0; i < DK8; ++i) { - k8x8_t mk; - simdgroup_load(mk, pk + i*8, args.nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10 + FOR_UNROLL (short i = 0; i < DK8; ++i) { + simdgroup_barrier(mem_flags::mem_none); + + simdgroup_load(mk, pk, NS10, 0, true); + simdgroup_load(mq, pq, DK); + + simdgroup_barrier(mem_flags::mem_none); - q8x8_t mq; - simdgroup_load(mq, sq + i*8, DK); simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + + pk += 8; + pq += 8; } } else { - for (short ii = 0; ii < DK16; ii += 4) { - device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); + k8x8_t mk[2]; + q8x8_t mq[2]; - if (DK16%4 == 0) { - // the head is evenly divisible by 4*16 = 64, so no need for bound checks - { - k4x4_t tmp; - deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); - sk4x4[4*ty + tx] = tmp; - } + FOR_UNROLL (short i = 0; i < DK8/2; ++i) { + simdgroup_barrier(mem_flags::mem_none); - simdgroup_barrier(mem_flags::mem_threadgroup); + simdgroup_load(mk[0], pk + 0*8, NS10, 0, true); + simdgroup_load(mk[1], pk + 1*8, NS10, 0, true); - #pragma unroll(4) - for (short k = 0; k < 4; ++k) { - k8x8_t mk; - q8x8_t mq; + simdgroup_load(mq[0], pq + 0*8, DK); + simdgroup_load(mq[1], pq + 1*8, DK); - simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose - simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); - simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + simdgroup_barrier(mem_flags::mem_none); - simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose - simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); - simdgroup_multiply_accumulate(mqk, mq, mk, mqk); - } - } else { - if (ii + tx < DK16) { - k4x4_t tmp; - deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); - sk4x4[4*ty + tx] = tmp; - } + simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk); + simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk); - simdgroup_barrier(mem_flags::mem_threadgroup); + pk += 16; + pq += 16; + } + } - for (short k = 0; k < 4 && ii + k < DK16; ++k) { - k8x8_t mk; - q8x8_t mq; + simdgroup_store(mqk, ps, SH, 0, false); - simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose - simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); - simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + pk += 8*(NSG*NS10 - DK8); + pq += 8*(NSG*0 - DK8); + ps += 8*(NSG); + } + } else { + // TODO: this is the quantized K cache branch - not optimized yet + for (short ccc = 0; ccc < (C/8)/NSG; ++ccc) { + const short cc = ccc*NSG + sgitg; - simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose - simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); - simdgroup_multiply_accumulate(mqk, mq, mk, mqk); - } + const short tx = tiisg%4; + const short ty = tiisg/4; + + qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); + + for (short ii = 0; ii < DK16; ii += 4) { + device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11)); + + if (DK16%4 == 0) { + // the head is evenly divisible by 4*16 = 64, so no need for bound checks + { + k4x4_t tmp; + deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); + sk4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + FOR_UNROLL (short k = 0; k < 4; ++k) { + k8x8_t mk; + q8x8_t mq; + + simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + + simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + } + } else { + if (ii + tx < DK16) { + k4x4_t tmp; + deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); + sk4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + for (short k = 0; k < 4 && ii + k < DK16; ++k) { + k8x8_t mk; + q8x8_t mq; + + simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + + simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); } } } - // cast qk_t -> s_t - //s8x8_t mqks(1.0f); - //simdgroup_multiply(mqks, mqk, mqks); - //simdgroup_store(mqks, ss + 8*cc, TS, 0, false); - - simdgroup_store(mqk, ss + 8*cc, TS, 0, false); + simdgroup_store(mqk, ss + 8*cc, SH, 0, false); } } + threadgroup_barrier(mem_flags::mem_threadgroup); + // online softmax - { - for (ushort j = 0; j < Q; ++j) { - const float m = M[j]; + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; - // scale and apply the logitcap / mask - float s = ss[j*TS + tiisg]*args.scale; + const float m = M[jj]; - if (args.logit_softcap != 0.0f) { - s = args.logit_softcap*precise::tanh(s); + // scale and apply the logitcap / mask + float2 s2 = ss2[j*SH/2 + tiisg]*args.scale; + + if (FC_flash_attn_ext_has_scap) { + s2 = args.logit_softcap*precise::tanh(s2); + } + + // mqk = mqk + slope*mask + if (FC_flash_attn_ext_has_bias) { + s2 += s2_t(sm2[j*SH + tiisg])*slope; + } else { + s2 += s2_t(sm2[j*SH + tiisg]); + } + + M[jj] = simd_max(max(M[jj], max(s2[0], s2[1]))); + + const float ms = exp(m - M[jj]); + const float2 vs2 = exp(s2 - M[jj]); + + S[jj] = S[jj]*ms + simd_sum(vs2[0] + vs2[1]); + + // the P matrix from the paper (Q rows, C columns) + ss2[j*SH/2 + tiisg] = vs2; + + if (DV4 % NW == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) { + const short i = ii*NW + tiisg; + + so4[j*PV4 + i] *= ms; } - - // mqk = mqk + mask*slope - s += slope*ss[j*TS + C + tiisg]; - - M[j] = simd_max(max(M[j], s)); - - const float ms = exp(m - M[j]); - const float vs = exp(s - M[j]); - - S[j] = S[j]*ms + simd_sum(vs); - - // the P matrix from the paper (Q rows, C columns) - ss[j*TS + tiisg] = vs; - - // create a QxQ diagonal matrix for rescaling the output - if (tiisg == j) { - ss[j*TS + 2*C + j] = ms; + } else { + for (short i = tiisg; i < DV4; i += NW) { + so4[j*PV4 + i] *= ms; } } } - // O = diag(ms)*O - { - s8x8_t ms; - simdgroup_load(ms, ss + 2*C, TS, 0, false); - - #pragma unroll(DV8) - for (short i = 0; i < DV8; ++i) { - simdgroup_multiply(lo[i], ms, lo[i]); - } - } + threadgroup_barrier(mem_flags::mem_threadgroup); // O = O + (Q*K^T)*V { - for (short cc = 0; cc < C/8; ++cc) { - s8x8_t vs; - simdgroup_load(vs, ss + 8*cc, TS, 0, false); + // we can read directly from global memory + if (is_same::value) { + static_assert(PV8 % NSG == 0, ""); - if (is_same::value) { - // we can read directly from global memory - device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); + constexpr short NO = PV8/NSG; - #pragma unroll(DV8) - for (short i = 0; i < DV8; ++i) { - v8x8_t mv; - simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20 + o8x8_t lo[NO]; - simdgroup_multiply_accumulate(lo[i], vs, mv, lo[i]); + { + auto sot = so + 8*sgitg; + + FOR_UNROLL (short ii = 0; ii < NO; ++ii) { + simdgroup_load(lo[ii], sot, PV, 0, false); + + sot += 8*NSG; } - } else { - for (short ii = 0; ii < DV16; ii += 4) { - device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); + } + + { + auto sst = ss; + + device const v_t * pv = (device const v_t *) ((device const char *) v + ic*args.nb21); + + pv += 8*sgitg; + + FOR_UNROLL (short cc = 0; cc < C/8; ++cc) { + s8x8_t vs; + simdgroup_load(vs, sst, SH, 0, false); + + FOR_UNROLL (short ii = 0; ii < NO; ++ii) { + v8x8_t mv; + + simdgroup_load(mv, pv, NS20, 0, false); + simdgroup_multiply_accumulate(lo[ii], vs, mv, lo[ii]); + + pv += 8*NSG; + } + + pv += 8*(NS20 - NO*NSG); + sst += 8; + } + } + + { + auto sot = so + 8*sgitg; + + FOR_UNROLL (short ii = 0; ii < NO; ++ii) { + simdgroup_store(lo[ii], sot, PV, 0, false); + + sot += 8*NSG; + } + } + } else { + // TODO: this is the quantized V cache branch - not optimized yet + + const short tx = tiisg%4; + const short ty = tiisg/4; + + for (short cc = 0; cc < C/8; ++cc) { + s8x8_t vs; + simdgroup_load(vs, ss + 8*cc, SH, 0, false); + + for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) { + device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21)); if (DV16%4 == 0) { // no need for bound checks @@ -4513,15 +4707,20 @@ kernel void kernel_flash_attn_ext( simdgroup_barrier(mem_flags::mem_threadgroup); - #pragma unroll(4) - for (short k = 0; k < 4; ++k) { - v8x8_t mv; + FOR_UNROLL (short k = 0; k < 4; ++k) { + v8x8_t mv[2]; + o8x8_t lo[2]; - simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]); + simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false); + simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false); + simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false); + simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false); - simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]); + simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]); + simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]); + + simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false); + simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false); } } else { if (ii + tx < DV16) { @@ -4533,243 +4732,249 @@ kernel void kernel_flash_attn_ext( simdgroup_barrier(mem_flags::mem_threadgroup); for (short k = 0; k < 4 && ii + k < DV16; ++k) { - v8x8_t mv; + v8x8_t mv[2]; + o8x8_t lo[2]; - simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]); + simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false); + simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false); + simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false); + simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false); - simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]); + simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]); + simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]); + + simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false); + simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false); } } } } } } + + threadgroup_barrier(mem_flags::mem_threadgroup); } - if (sinks != q && sgitg == 0) { - for (ushort j = 0; j < Q; ++j) { - const float m = M[j]; + if (FC_flash_attn_ext_has_sinks) { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + const float m = M[jj]; const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2; - M[j] = simd_max(max(M[j], s)); + M[jj] = simd_max(max(M[jj], s)); - const float ms = exp(m - M[j]); - const float vs = exp(s - M[j]); + const float ms = exp(m - M[jj]); + const float vs = exp(s - M[jj]); - S[j] = S[j]*ms + simd_sum(vs); + S[jj] = S[jj]*ms + simd_sum(vs); - if (tiisg == j) { - ss[j*TS + 2*C + j] = ms; - } - } - - // O = diag(ms)*O - { - s8x8_t ms; - simdgroup_load(ms, ss + 2*C, TS, 0, false); - - #pragma unroll(DV8) - for (short i = 0; i < DV8; ++i) { - simdgroup_multiply(lo[i], ms, lo[i]); + for (short i = tiisg; i < DV4; i += NW) { + so4[j*PV4 + i] *= ms; } } } - - // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - for (short j = tiisg; j < Q; j += NW) { - ss[j*TS + 0] = S[j]; - ss[j*TS + 1] = M[j]; - } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - threadgroup float * so = (threadgroup float *) (shmem_f16 + 0*DK); // reuse query data for accumulation - threadgroup float4 * so4 = (threadgroup float4 *) (shmem_f16 + 0*DK); - - // store result to shared memory in F32 - if (sgitg == 0) { - for (short i = 0; i < DV8; ++i) { - //simdgroup_store(lo[i], so + i*8, DV, 0, false); - simdgroup_float8x8 t(1.0f); - simdgroup_multiply(t, lo[i], t); - simdgroup_store(t, so + i*8, DV, 0, false); + // store to global memory + for (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + if (iq1 + j >= args.ne01) { + break; } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // reduce the warps sequentially - for (ushort sg = 1; sg < nsg; ++sg) { - if (sgitg == sg) { - for (short j = tiisg; j < Q; j += NW) { - const float S0 = ss[j*TS - 1*SH + 0]; - const float S1 = ss[j*TS + 0]; - - const float M0 = ss[j*TS - 1*SH + 1]; - const float M1 = ss[j*TS + 1]; - - const float M = max(M0, M1); - - float ms0 = exp(M0 - M); - float ms1 = exp(M1 - M); - - const float S = S0*ms0 + S1*ms1; - - ss[j*TS + 0] = S; - ss[j*TS + 1] = M; - - ss[j*TS + 2*C + j - 1*SH] = ms0; - ss[j*TS + 2*C + j ] = ms1; - } - - //simdgroup_barrier(mem_flags::mem_threadgroup); - - // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - { - s8x8_t ms0; - s8x8_t ms1; - - simdgroup_load(ms0, ss + 2*C - 1*SH, TS, 0, false); - simdgroup_load(ms1, ss + 2*C, TS, 0, false); - - #pragma unroll(DV8) - for (short i = 0; i < DV8; ++i) { - simdgroup_float8x8 t; - - simdgroup_load (t, so + i*8, DV, 0, false); - simdgroup_multiply(t, ms0, t); - - simdgroup_multiply_accumulate(t, ms1, lo[i], t); - simdgroup_store(t, so + i*8, DV, 0, false); - } - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*(nsg-1)*SH + 2*Q*DK); - - // final rescale with 1/S and store to global memory - for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) { - const float S = 1.0f/sf[j*TS + 0]; device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4; - for (short i = tiisg; i < DV4; i += NW) { - dst4[i] = (float4) so4[j*DV4 + i]*S; + const float scale = 1.0f/S[jj]; + + if (DV4 % NW == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) { + const short i = ii*NW + tiisg; + + dst4[i] = (float4) so4[j*PV4 + i]*scale; + } + } else { + for (short i = tiisg; i < DV4; i += NW) { + dst4[i] = (float4) so4[j*PV4 + i]*scale; + } } } + +#undef NS10 +#undef NS20 +} + +template< + typename q_t, // query types in shared memory + typename q4_t, + typename q8x8_t, + typename k_t, // key types in shared memory + typename k4x4_t, + typename k8x8_t, + typename v_t, // value types in shared memory + typename v4x4_t, + typename v8x8_t, + typename qk_t, // Q*K types + typename qk8x8_t, + typename s_t, // soft-max types + typename s2_t, + typename s8x8_t, + typename o_t, // attention accumulation types + typename o4_t, + typename o8x8_t, + typename kd4x4_t, // key type in device memory + short nl_k, + void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &), + typename vd4x4_t, // value type in device memory + short nl_v, + void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), + short DK, // K head size + short DV, // V head size + short Q = 8, // queries per threadgroup + short C = 64> // cache items per threadgroup +kernel void kernel_flash_attn_ext( + constant ggml_metal_kargs_flash_attn_ext & args, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device const char * sinks, + device char * dst, + threadgroup half * shmem_f16 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { +#define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C +#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg + switch (FC_flash_attn_ext_nsg) { + // note: disabled cases to reduce library load time + //case 1: kernel_flash_attn_ext_impl(FWD_ARGS); break; + //case 2: kernel_flash_attn_ext_impl(FWD_ARGS); break; + case 4: kernel_flash_attn_ext_impl(FWD_ARGS); break; + } +#undef FWD_TMPL +#undef FWD_ARGS } // TODO: this is quite ugly. in the future these types will be hardcoded in the kernel, but for now keep them as // template to be able to explore different combinations // #define FA_TYPES \ - float, float4, simdgroup_float8x8, \ + half, half4, simdgroup_half8x8, \ half, half4x4, simdgroup_half8x8, \ half, half4x4, simdgroup_half8x8, \ float, simdgroup_float8x8, \ - float, simdgroup_float8x8, \ - half, half4, simdgroup_half8x8 - //float, float4, simdgroup_float8x8 + float, float2, simdgroup_float8x8, \ + float, float4, simdgroup_float8x8 + //half, half4, simdgroup_half8x8 #define FA_TYPES_BF \ bfloat, bfloat4, simdgroup_bfloat8x8, \ bfloat, bfloat4x4, simdgroup_bfloat8x8, \ bfloat, bfloat4x4, simdgroup_bfloat8x8, \ float, simdgroup_float8x8, \ - float, simdgroup_float8x8, \ + float, float2, simdgroup_float8x8, \ half, half4, simdgroup_half8x8 //float, float4, simdgroup_float8x8 typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t; -template [[host_name("kernel_flash_attn_ext_f16_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_bf16_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #endif -template [[host_name("kernel_flash_attn_ext_q4_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #undef FA_TYPES #undef FA_TYPES_BF +constant bool FC_flash_attn_ext_vec_has_mask [[function_constant(FC_FLASH_ATTN_EXT_VEC + 0)]]; +constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]]; +constant bool FC_flash_attn_ext_vec_has_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 2)]]; +constant bool FC_flash_attn_ext_vec_has_scap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 3)]]; + +//constant float FC_flash_attn_ext_vec_scale [[function_constant(FC_FLASH_ATTN_EXT_VEC + 10)]]; +//constant float FC_flash_attn_ext_vec_max_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 11)]]; +//constant float FC_flash_attn_ext_vec_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 12)]]; + +constant int32_t FC_flash_attn_ext_vec_ns10 [[function_constant(FC_FLASH_ATTN_EXT_VEC + 20)]]; +constant int32_t FC_flash_attn_ext_vec_ns20 [[function_constant(FC_FLASH_ATTN_EXT_VEC + 21)]]; +constant int32_t FC_flash_attn_ext_vec_nsg [[function_constant(FC_FLASH_ATTN_EXT_VEC + 22)]]; +constant int32_t FC_flash_attn_ext_vec_nwg [[function_constant(FC_FLASH_ATTN_EXT_VEC + 23)]]; + template< typename q4_t, // query types in shared memory typename k4_t, // key types in shared memory @@ -4788,63 +4993,86 @@ template< short DV, // V head size short NE = 4, // head elements per thread short Q = 1, // queries per threadgroup - short C = 32> // cache items per threadgroup -kernel void kernel_flash_attn_ext_vec( - constant ggml_metal_kargs_flash_attn_ext & args, + short C = 32, // cache items per threadgroup + short NSG> // number of simd groups +void kernel_flash_attn_ext_vec_impl( + constant ggml_metal_kargs_flash_attn_ext_vec & args, device const char * q, device const char * k, device const char * v, device const char * mask, device const char * sinks, device char * dst, - constant uint16_t & nwg, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 ntg[[threads_per_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { static_assert(DK % 32 == 0, "DK must be divisible by 32"); static_assert(DV % 32 == 0, "DV must be divisible by 32"); - const short nsg = ntg.y; // number of simdgroups - const short iwg = tgpig[2]%nwg; +#define NWG (FC_flash_attn_ext_vec_nwg) - const int iq3 = tgpig[2]/nwg; - const int iq2 = tgpig[1]; - const int iq1 = tgpig[0]; +#define NS10 (FC_flash_attn_ext_vec_ns10) +#define NS20 (FC_flash_attn_ext_vec_ns20) + + const short iwg = tgpig[2]%NWG; + + const ushort iq3 = tgpig[2]/NWG; + const ushort iq2 = tgpig[1]; + const ushort iq1 = tgpig[0]; constexpr short DK4 = DK/4; constexpr short DV4 = DV/4; + + constexpr short PK = PAD2(DK, 128); + constexpr short PK4 = PK/4; + + constexpr short PV = PAD2(DV, 128); + constexpr short PV4 = PV/4; + constexpr short NW = N_SIMDWIDTH; constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads constexpr short SH = 4*C; // shared memory per simdgroup - const short T = DK + nsg*SH; // shared memory size per query in (half) + static_assert(DK4 % NL == 0, "DK4 must be divisible by NL"); + static_assert(DV4 % NL == 0, "DV4 must be divisible by NL"); - //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t - threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention - threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t - threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask - threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*DV + Q*T); // scratch buffer for the results + const short T = PK + NSG*SH; // shared memory size per query in (half) - // store the result for all queries in local memory (the O matrix from the paper) - o4_t lo[DV4/NL]; + //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*PK); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*PK); // same as above but in s4_t + threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + Q*PK); // scratch buffer for mask + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + Q*T); // scratch buffer for the results + + // store the result for all queries in shared memory (the O matrix from the paper) + so4 += tiisg; + + { + q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + k += ikv2*args.nb12 + ikv3*args.nb13; + v += ikv2*args.nb22 + ikv3*args.nb23; + } // load heads from Q to shared memory - device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); + device const float4 * q4 = (device const float4 *) ((device const char *) q); - for (short i = tiisg; i < DK4; i += NW) { - if (iq1 < args.ne01) { + for (short i = tiisg; i < PK4; i += NW) { + if (iq1 < args.ne01 && i < DK4) { sq4[i] = (q4_t) q4[i]; } else { sq4[i] = (q4_t) 0.0f; } } - // zero out lo + // zero out so for (short i = 0; i < DV4/NL; ++i) { - lo[i] = (o4_t) 0.0f; + so4[i*NL] = (o4_t) 0.0f; } // zero out shared memory SH @@ -4856,28 +5084,19 @@ kernel void kernel_flash_attn_ext_vec( { float S = 0.0f; - float M = -__FLT_MAX__/2; + float M = -FLT_MAX/2; // thread indices inside the simdgroup const short tx = tiisg%NL; const short ty = tiisg/NL; - // broadcast kv - //const short rk2 = args.ne02/args.ne12; - //const short rk3 = args.ne03/args.ne13; - - const short ikv2 = iq2/(args.ne02/args.ne_12_2); - const short ikv3 = iq3/(args.ne03/args.ne_12_3); - - const bool has_mask = mask != q; - // pointer to the mask device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); float slope = 1.0f; // ALiBi - if (args.max_bias > 0.0f) { + if (FC_flash_attn_ext_vec_has_bias) { const short h = iq2; const float base = h < args.n_head_log2 ? args.m0 : args.m1; @@ -4888,13 +5107,13 @@ kernel void kernel_flash_attn_ext_vec( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = (int) iwg*C*nsg; ic0 < args.ne11; ic0 += (int) nwg*C*nsg) { + for (int ic0 = (int) iwg*C*NSG; ic0 < args.ne11; ic0 += (int) NWG*C*NSG) { const int ic = ic0 + C*sgitg; if (ic >= args.ne11) { break; } - if (has_mask) { + if (FC_flash_attn_ext_vec_has_mask) { sm[tiisg] = pm[ic + tiisg]; } @@ -4905,70 +5124,82 @@ kernel void kernel_flash_attn_ext_vec( // Q*K^T { - // each simdgroup processes 1 query and NE (NW/NL) head elements - for (short cc = 0; cc < C/NE; ++cc) { - qk_t mqk = 0.0f; + device const k4_t * pk4 = (device const k4_t *) ((device const char *) k + ic*args.nb11); + threadgroup const q4_t * pq4 = sq4; - device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); + pk4 += ty*NS10/4 + tx; + pq4 += tx; - #pragma unroll(DK4/NL) - for (short ii = 0; ii < DK4; ii += NL) { - const short i = ii + tx; + qk_t mqk[C/NE] = { [ 0 ... C/NE - 1] = 0.0f }; + + // each simdgroup processes 1 query and NE (NW/NL) cache elements + FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { + if (is_same::value) { + FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) { + mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 + ii*NL], (float4) pq4[ii*NL]); + } + } else { + device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11)); k4_t mk; - deq_k_t4(pk + i/nl_k, i%nl_k, mk); - // note: this is less precise than the version below - //mqka[0] += dot(mq[0], mk[0]); - //mqka[1] += dot(mq[1], mk[1]); - //mqka[2] += dot(mq[2], mk[2]); - //mqka[3] += dot(mq[3], mk[3]); + FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) { + const short i = ii*NL + tx; - //q4x4_t mq = sq4x4[i]; - //mqka[0] += dot((float4) mq[0], (float4) mk[0]); - //mqka[1] += dot((float4) mq[1], (float4) mk[1]); - //mqka[2] += dot((float4) mq[2], (float4) mk[2]); - //mqka[3] += dot((float4) mq[3], (float4) mk[3]); + deq_k_t4(pk + i/nl_k, i%nl_k, mk); - mqk += dot((float4) mk, (float4) sq4[i]); + mqk[cc] += dot((float4) mk, (float4) sq4[i]); + } } - static_assert(NE > 1, "NE must be > 1"); // note: not sure why NE == 1 fails - - // simdgroup reduce (NE = 4) - // [ 0 .. 7] -> [ 0] - // [ 8 .. 15] -> [ 8] - // [16 .. 23] -> [16] - // [24 .. 31] -> [24] - if (NE <= 1) { - mqk += simd_shuffle_down(mqk, 16); - } - if (NE <= 2) { - mqk += simd_shuffle_down(mqk, 8); - } - if (NE <= 4) { - mqk += simd_shuffle_down(mqk, 4); - } - if (NE <= 8) { - mqk += simd_shuffle_down(mqk, 2); - } - if (NE <= 16) { - mqk += simd_shuffle_down(mqk, 1); - } - - // mqk = mqk*scale + mask*slope - if (tx == 0) { - mqk *= args.scale; - - if (args.logit_softcap != 0.0f) { - mqk = args.logit_softcap*precise::tanh(mqk); + if (NE == 1) { + mqk[cc] = simd_sum(mqk[cc]); + } else { + // simdgroup reduce (NE = 4) + // [ 0 .. 7] -> [ 0] + // [ 8 .. 15] -> [ 8] + // [16 .. 23] -> [16] + // [24 .. 31] -> [24] + if (NE <= 1) { + mqk[cc] += simd_shuffle_down(mqk[cc], 16); + } + if (NE <= 2) { + mqk[cc] += simd_shuffle_down(mqk[cc], 8); + } + if (NE <= 4) { + mqk[cc] += simd_shuffle_down(mqk[cc], 4); + } + if (NE <= 8) { + mqk[cc] += simd_shuffle_down(mqk[cc], 2); + } + if (NE <= 16) { + mqk[cc] += simd_shuffle_down(mqk[cc], 1); } - mqk += sm[NE*cc + ty]*slope; - - ss[NE*cc + ty] = mqk; + // broadcast + mqk[cc] = simd_shuffle(mqk[cc], NL*ty); } } + + if (FC_flash_attn_ext_vec_has_mask && + !FC_flash_attn_ext_vec_has_scap && + !FC_flash_attn_ext_vec_has_bias) { + ss[NE*tx + ty] = fma(mqk[tx], args.scale, (qk_t) sm[NE*tx + ty]); + } else { + mqk[tx] *= args.scale; + + if (FC_flash_attn_ext_vec_has_scap) { + mqk[tx] = args.logit_softcap*precise::tanh(mqk[tx]); + } + + if (FC_flash_attn_ext_vec_has_bias) { + mqk[tx] += (qk_t) sm[NE*tx + ty]*slope; + } else { + mqk[tx] += (qk_t) sm[NE*tx + ty]; + } + + ss[NE*tx + ty] = mqk[tx]; + } } simdgroup_barrier(mem_flags::mem_threadgroup); @@ -4989,9 +5220,10 @@ kernel void kernel_flash_attn_ext_vec( ss[tiisg] = vs; // O = diag(ms)*O - #pragma unroll(DV4/NL) - for (short ii = 0; ii < DV4; ii += NL) { - lo[ii/NL] *= ms; + if ((DV4/NL % NW == 0) || ty == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + so4[ii*NL] *= ms; + } } } @@ -4999,26 +5231,84 @@ kernel void kernel_flash_attn_ext_vec( // O = O + (Q*K^T)*V { - //#pragma unroll(C/NE) - for (short cc = 0; cc < C/NE; ++cc) { - device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); + o4_t lo[DV4/NL]; + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + lo[ii] = 0.0f; + } - const s4_t ms(ss[NE*cc + ty]); + if (is_same::value) { + device const v4_t * pv4 = (device const v4_t *) ((device const char *) v + ic*args.nb21); - #pragma unroll(DV4/NL) - for (short ii = 0; ii < DV4; ii += NL) { - const short i = ii + tx; + pv4 += ty*NS20/4 + tx; - v4_t mv; - deq_v_t4(pv4 + i/nl_v, i%nl_v, mv); + const auto sst = ss + ty; - lo[ii/NL] += o4_t(float4(mv)*float4(ms)); + FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + lo[ii] += o4_t(float4(pv4[cc*NE*NS20/4 + ii*NL])*float4(sst[cc*NE])); + } + } + } else { + FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { + device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21)); + + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + const short i = ii*NL + tx; + + v4_t mv; + deq_v_t4(pv4 + i/nl_v, i%nl_v, mv); + + lo[ii] += o4_t(float4(mv)*float4(ss[NE*cc + ty])); + } + } + } + + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + if (NE > 1) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 16); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 16); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 16); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 16); + } + + if (NE > 2) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 8); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 8); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 8); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 8); + } + + if (NE > 4) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 4); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 4); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 4); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 4); + } + + if (NE > 8) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 2); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 2); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 2); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 2); + } + + if (NE > 16) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 1); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 1); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 1); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 1); + } + } + + if ((DV4/NL % NW == 0) || ty == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + so4[ii*NL] += lo[ii]; } } } } - if (sinks != q && sgitg == 0 && iwg == 0) { + if (FC_flash_attn_ext_vec_has_sinks && sgitg == 0 && iwg == 0) { const float m = M; const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2; @@ -5029,9 +5319,10 @@ kernel void kernel_flash_attn_ext_vec( S = S*ms + simd_sum(vs); -#pragma unroll(DV4/NL) - for (short ii = 0; ii < DV4; ii += NL) { - lo[ii/NL] *= ms; + if ((DV4/NL % NW == 0) || ty == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + so4[ii*NL] *= ms; + } } } @@ -5042,63 +5333,12 @@ kernel void kernel_flash_attn_ext_vec( } } - // simdgroup reduce (NE = 4) - // [ 0, 8, 16, 24] -> [ 0] - // [ 1, 9, 17, 25] -> [ 1] - // [ 2, 10, 18, 26] -> [ 2] - // [ 3, 11, 19, 27] -> [ 3] - // [ 4, 12, 20, 28] -> [ 4] - // [ 5, 13, 21, 29] -> [ 5] - // [ 6, 14, 22, 30] -> [ 6] - // [ 7, 15, 23, 31] -> [ 7] - for (short ii = 0; ii < DV4; ii += NL) { - if (NE > 1) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16); - } - - if (NE > 2) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8); - } - - if (NE > 4) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4); - } - - if (NE > 8) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2); - } - - if (NE > 16) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1); - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // store results to shared memory - for (short i = tiisg; i < DV4; i += NL) { - sr4[i] = lo[i/NL]; - } + so4 -= tiisg; threadgroup_barrier(mem_flags::mem_threadgroup); // parallel reduce - for (short r = nsg/2; r > 0; r >>= 1) { + for (short r = NSG/2; r > 0; r >>= 1) { if (sgitg < r) { const float S0 = ss[ 0]; const float S1 = ss[r*(SH/2) + 0]; @@ -5120,7 +5360,7 @@ kernel void kernel_flash_attn_ext_vec( // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 for (short i = tiisg; i < DV4; i += NW) { - sr4[i] = sr4[i]*ms0 + sr4[i + r*DV4]*ms1; + so4[i] = so4[i]*ms0 + so4[i + r*PV4]*ms1; } } @@ -5133,21 +5373,73 @@ kernel void kernel_flash_attn_ext_vec( const int64_t rid = iq3*args.ne2*args.ne1 + iq2 + iq1*args.ne1; device float4 * dst4 = (device float4 *) dst; - device float * dst1 = (device float *) dst + nrows*DV*nwg; // the S and M are stored after the results + device float * dst1 = (device float *) dst + nrows*DV*NWG; // the S and M are stored after the results - const float S = nwg == 1 ? 1.0f/ss[0] : 1.0f; + const float S = NWG == 1 ? 1.0f/ss[0] : 1.0f; // interleave the workgroup data for (short i = tiisg; i < DV4; i += NW) { - dst4[rid*DV4*nwg + nwg*i + iwg] = (float4) sr4[i]*S; + dst4[rid*DV4*NWG + NWG*i + iwg] = (float4) so4[i]*S; } // store S and M - if (nwg > 1 && tiisg == 0) { - dst1[rid*(2*nwg) + 2*iwg + 0] = ss[0]; - dst1[rid*(2*nwg) + 2*iwg + 1] = ss[1]; + if (NWG > 1) { + if (tiisg == 0) { + dst1[rid*(2*NWG) + 2*iwg + 0] = ss[0]; + dst1[rid*(2*NWG) + 2*iwg + 1] = ss[1]; + } } } + +#undef NWG +#undef NS10 +#undef NS20 +} + +template< + typename q4_t, // query types in shared memory + typename k4_t, // key types in shared memory + typename v4_t, // value types in shared memory + typename qk_t, // Q*K types + typename s_t, // soft-max types + typename s4_t, + typename o4_t, // attention accumulation types + typename kd4_t, // key type in device memory + short nl_k, + void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &), + typename vd4_t, // value type in device memory + short nl_v, + void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), + short DK, // K head size + short DV, // V head size + short NE = 4, // head elements per thread + short Q = 1, // queries per threadgroup + short C = 32> // cache items per threadgroup +kernel void kernel_flash_attn_ext_vec( + constant ggml_metal_kargs_flash_attn_ext_vec & args, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device const char * sinks, + device char * dst, + threadgroup half * shmem_f16 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { +#define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C +#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg + switch (FC_flash_attn_ext_vec_nsg) { + // note: disabled cases to reduce library load time + case 1: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + case 2: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + case 4: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + //case 8: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + //case 16: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + //case 32: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + } +#undef FWD_TMPL +#undef FWD_ARGS } // note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem @@ -5163,111 +5455,120 @@ kernel void kernel_flash_attn_ext_vec( typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; -template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #undef FA_TYPES -kernel void kernel_flash_attn_ext_reduce( - constant ggml_metal_kargs_flash_attn_ext_reduce & args, +constant int32_t FC_flash_attn_ext_vec_reduce_DV [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 0)]]; +constant int32_t FC_flash_attn_ext_vec_reduce_NWG [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 1)]]; + +kernel void kernel_flash_attn_ext_vec_reduce( + constant ggml_metal_kargs_flash_attn_ext_vec_reduce & args, device const char * htmp, device char * dst, uint tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { +#define NWG (FC_flash_attn_ext_vec_reduce_NWG) +#define DV (FC_flash_attn_ext_vec_reduce_DV) + const uint64_t rid = tgpig; - const short nwg = 32; const short iwg = tiisg; - const short DV = args.ne20; - const short DV4 = DV/4; - device const float4 * htmp4 = (device const float4 *) htmp + rid*DV4*nwg; - device const float * ss = (device const float *) htmp + (uint64_t)args.nrows*DV*nwg; - device float4 * dst4 = (device float4 *) dst + rid*DV4; + device const float * ss = (device const float *) htmp + (uint64_t)args.nrows*DV*NWG; - float S = ss[rid*(2*nwg) + 2*iwg + 0]; - float M = ss[rid*(2*nwg) + 2*iwg + 1]; + float S = ss[rid*(2*NWG) + 2*iwg + 0]; + float M = ss[rid*(2*NWG) + 2*iwg + 1]; const float m = simd_max(M); const float ms = exp(M - m); S = 1.0f/simd_sum(S*ms); - for (int i = sgitg; i < DV4; i += nwg) { - const float4 v = simd_sum(htmp4[i*nwg + iwg]*ms); + const short DV4 = DV/4; + + device const float4 * htmp4 = (device const float4 *) htmp + rid*DV4*NWG; + device float4 * dst4 = (device float4 *) dst + rid*DV4; + + for (short i = sgitg; i < DV4; i += NWG) { + const float4 v = simd_sum(htmp4[i*NWG + iwg]*ms); if (iwg == 0) { dst4[i] = v*S; } } + +#undef NWG +#undef DV } template @@ -7397,7 +7698,7 @@ kernel void kernel_set_rows_f( const int32_t i10 = i01; const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0]; - device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3); + device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3); const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03); for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) { @@ -7496,18 +7797,20 @@ kernel void kernel_mul_mm( #pragma unroll(4) for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) { + simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(4) for (short i = 0; i < 4; i++) { simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); } - simdgroup_barrier(mem_flags::mem_none); - #pragma unroll(2) for (short i = 0; i < 2; i++) { simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); } + simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(8) for (short i = 0; i < 8; i++){ simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index e8e6e1000b..cca86a8ce8 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6501,8 +6501,8 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v)); } - for (int hsk : { 40, 64, 80, 128, 192, 256, 576 }) { - for (int hsv : { 40, 64, 80, 128, 192, 256, 512 }) { + for (int hsk : { 40, 64, 80, 96, 128, 192, 256, 576 }) { + for (int hsv : { 40, 64, 80, 96, 128, 192, 256, 512 }) { if (hsk != 192 && hsk != 576 && hsk != hsv) continue; if (hsk == 192 && (hsv != 128 && hsv != 192)) continue; if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA From b0d52998b962bd2681c34bf52af993af79f178b8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 8 Sep 2025 13:56:51 +0300 Subject: [PATCH 18/46] cuda : fix supports_op condition for get_rows when number of blocks is too large (#15868) * cuda : fix supports_op condition for get_rows when src1->ne2 > 1 ggml-ci * ggml : add comment about ggml_get_rows ggml-ci * cuda : add FIXME [no ci] * cuda : update support condition ggml-ci --- ggml/include/ggml.h | 6 +++++- ggml/src/ggml-cuda/ggml-cuda.cu | 4 ++++ ggml/src/ggml.c | 1 + 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 058f4267f7..b7b472c56e 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1529,7 +1529,11 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); - // supports 3D: a->ne[2] == b->ne[1] + // supports 4D a: + // a [n_embd, ne1, ne2, ne3] + // b I32 [n_rows, ne2, ne3, 1] + // + // return [n_embd, n_rows, ne2, ne3] GGML_API struct ggml_tensor * ggml_get_rows( struct ggml_context * ctx, struct ggml_tensor * a, // data diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index a88b9f75ef..0c6bd36396 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3392,6 +3392,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; case GGML_OP_GET_ROWS: { + // FIXME: https://github.com/ggml-org/llama.cpp/pull/15868 + if (op->src[1]->ne[1]*op->src[1]->ne[2] > 65535) { + return false; + } switch (op->src[0]->type) { case GGML_TYPE_F16: case GGML_TYPE_F32: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index f35c337952..50dc1aa24f 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3623,6 +3623,7 @@ struct ggml_tensor * ggml_get_rows( struct ggml_tensor * a, struct ggml_tensor * b) { GGML_ASSERT(a->ne[2] == b->ne[1]); + GGML_ASSERT(a->ne[3] == b->ne[2]); GGML_ASSERT(b->ne[3] == 1); GGML_ASSERT(b->type == GGML_TYPE_I32); From 56920f56651908d5cce7a310dabf54ac4f6fbb7f Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Mon, 8 Sep 2025 21:50:05 +0700 Subject: [PATCH 19/46] server : bring back timings_per_token (#15879) --- tools/server/server.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 70aa187564..39ef439e94 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -308,6 +308,7 @@ struct server_task { // enabling this will output extra debug information in the HTTP responses from the server params.verbose = params_base.verbosity > 9; + params.timings_per_token = json_value(data, "timings_per_token", false); params.stream = json_value(data, "stream", false); params.cache_prompt = json_value(data, "cache_prompt", true); From 88021565f08e0b7c4e07ac089a15ec16fae9166c Mon Sep 17 00:00:00 2001 From: Jesse Date: Mon, 8 Sep 2025 10:59:48 -0400 Subject: [PATCH 20/46] chat : Deepseek V3.1 reasoning and tool calling support (OpenAI Style) (#15533) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add DeepSeek V3.1 thinking mode support - Added COMMON_CHAT_FORMAT_DEEPSEEK_V3_1 enum value - Created common_chat_params_init_deepseek_v3_1() function (currently uses R1 implementation) - Created common_chat_parse_deepseek_v3_1() function that handles V3.1 thinking format: - Extracts reasoning content before '' tag into reasoning_content - Extracts regular content after '' tag into content - No opening '' tag in V3.1 format - Added detection logic for V3.1 templates based on pattern: 'message['prefix'] is defined and message['prefix'] and thinking' - Added V3.1 case to parsing switch statement This addresses the issue where V3.1 outputs reasoning content followed by '' and then regular content without the opening '' tag. * Another attempt by V3.1 non-thinking * Fix test, but it's not asserting anything. * Ignore vim swap files in tests dir * Update the test * Try using try_find_literal instead of regex * passing test * Revert "Try using try_find_literal instead of regex" This reverts commit c50d887ec2780dd9e6b8b397e92347d3db8d5575. * Remove unnecessary change * Remove comment * Add code to handle non-thinking mode. * Try to set message['prefix'] when thinking is enabled. * This fixes reasoning, but breaks normal content. We need state in the chat parser. * DeepSeek V3.1 thinking is now the default. Disable with `--reasoning-budget 0`. * Simplify (DeepSeek V3.1 reasoning) * Fix sign inversion bug * Add some tool calling code (not working). * Tool calls working in non-reasoning mode. * Attempt a unit test for tool call parsing. * Passing test * Add tests for both happy path and broken fenced DeepSeek V3.1 tool call variants. * Passing DeepSeek V3.1 tool call tests, but model is not working. * Revert assistance response prefill change. Not my monkeys. * Add fenced_thinking unit test variant. Passes, but thinking tool calling still isn't working for some reason. * Tests pass in reasoning mode. Also e2e tool test passes. * Make a copy of the parse_json_tool_calls function for deepseek-v3.1 so as to not accidentally introduce regressions. * Fix thinking_forced_open logic. tool calling broken. Need to add another test case. * That's what I get for cargo culting a newline. * Add multi tool call test for deepseek v3.1 non-reasoning * Move test, remove .gitignore change * Place deepseek-v3.1 reasoning test directly into existing reasoning function per CISC's request. * Address whitespace CI failure. * Merge two assert_equals per CISC's request. * Add DeepSeek-V3.1 tests to tests/test-chat.cpp per CISC's request. * Merge deepseek V3.1 and regular parse_json_tool_calls() function behaviors by adding optional update_cursor argument. * Update tests/test-chat-parser.cpp Co-authored-by: Sigbjørn Skjæret * Update tests/test-chat-parser.cpp Co-authored-by: Sigbjørn Skjæret * Update tests/test-chat-parser.cpp Co-authored-by: Sigbjørn Skjæret * Update tests/test-chat-parser.cpp Co-authored-by: Sigbjørn Skjæret * Update tests/test-chat-parser.cpp Co-authored-by: Sigbjørn Skjæret * Update tests/test-chat-parser.cpp Co-authored-by: Sigbjørn Skjæret * Update tests/test-chat-parser.cpp Co-authored-by: Sigbjørn Skjæret * Update tests/test-chat-parser.cpp Co-authored-by: Sigbjørn Skjæret * Update tests/test-chat-parser.cpp Co-authored-by: Sigbjørn Skjæret * DeepSeek V3.1 fix reasoning_format none * Strip grammar down to strictly what we expect based on model card. Throw out parts we cargo culted from R1 that don't make sense. * Update tests/test-chat-parser.cpp Co-authored-by: Sigbjørn Skjæret * DeepSeek V3.1 - Add edge case where thinking is forced open, there is tool calling in the reasoning content, but then the model just stops the output without closing the tag, so it's not a partial. In this case, use the tool call in the reasoning content. * DeepSeek V3.1 - simplify update_cursor * Update common/chat.cpp Co-authored-by: Sigbjørn Skjæret * Update common/chat.cpp Co-authored-by: Sigbjørn Skjæret * Update common/chat.cpp Co-authored-by: Sigbjørn Skjæret * Fix indent --------- Co-authored-by: openhands Co-authored-by: Sigbjørn Skjæret --- common/chat.cpp | 139 +++++++++++++ common/chat.h | 1 + models/templates/README.md | 1 + .../templates/deepseek-ai-DeepSeek-V3.1.jinja | 3 + tests/test-chat-parser.cpp | 193 +++++++++++++++++- tests/test-chat.cpp | 137 ++++++++++++- 6 files changed, 472 insertions(+), 2 deletions(-) create mode 100644 models/templates/deepseek-ai-DeepSeek-V3.1.jinja diff --git a/common/chat.cpp b/common/chat.cpp index a8a4c3e3c1..4707c4fef4 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -631,6 +631,7 @@ const char * common_chat_format_name(common_chat_format format) { case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2"; case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2"; case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1"; + case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1: return "DeepSeek V3.1"; case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro"; case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B"; case COMMON_CHAT_FORMAT_GRANITE: return "Granite"; @@ -698,11 +699,13 @@ static void parse_json_tool_calls( size_t from = std::string::npos; auto first = true; while (true) { + auto start_pos = builder.pos(); auto res = function_regex_start_only && first ? builder.try_consume_regex(*function_regex_start_only) : function_regex ? builder.try_find_regex(*function_regex, from) : std::nullopt; + if (res) { std::string name; if (get_function_name) { @@ -737,6 +740,8 @@ static void parse_json_tool_calls( return; } throw common_chat_msg_partial_exception("incomplete tool call"); + } else { + builder.move_to(start_pos); } break; } @@ -1388,6 +1393,71 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_ } return data; } + +static common_chat_params common_chat_params_init_deepseek_v3_1(const common_chat_template & tmpl, const struct templates_params & inputs) { + common_chat_params data; + + // Pass thinking context for DeepSeek V3.1 template + json additional_context = { + {"thinking", inputs.enable_thinking}, + }; + + auto prompt = apply(tmpl, inputs, + /* messages_override= */ inputs.messages, + /* tools_override= */ std::nullopt, + additional_context); + data.prompt = prompt; + data.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1; + if (string_ends_with(data.prompt, "")) { + if (!inputs.enable_thinking) { + data.prompt += ""; + } else { + data.thinking_forced_open = true; + } + } + if (inputs.tools.is_array() && !inputs.tools.empty()) { + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null(); + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector tool_rules; + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + auto parameters = function.at("parameters"); + builder.resolve_refs(parameters); + tool_rules.push_back(builder.add_rule(name + "-call", + "( \"<|tool▁call▁begin|>\" )? \"" + name + "<|tool▁sep|>" + "\" " + builder.add_schema(name + "-args", parameters) + " " + "\"<|tool▁call▁end|>\"")); + }); + // Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag, + // so we accept common variants (then it's all constrained) + builder.add_rule("root", + std::string(data.thinking_forced_open ? "( \"\" space )? " : "") + + "( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" | \"<|tool▁calls|>\" ) " + "(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " " + "\"<|tool▁calls▁end|>\"" + " space"); + data.grammar_triggers.push_back({ + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, + // If thinking_forced_open, then we capture the tag in the grammar, + // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) + std::string(data.thinking_forced_open ? "[\\s\\S]*?(\\s*)" : "(?:[\\s\\S]*?\\s*)?") + + "(<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)[\\s\\S]*" + }); + data.preserved_tokens = { + "", + "", + "<|tool▁calls▁begin|>", + "<|tool▁call▁begin|>", + "<|tool▁sep|>", + "<|tool▁call▁end|>", + "<|tool▁calls▁end|>", + }; + }); + } + return data; +} + static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { builder.try_parse_reasoning("", ""); if (!builder.syntax().parse_tool_calls) { @@ -1409,6 +1479,66 @@ static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { tool_calls_end); } +static void common_chat_parse_deepseek_v3_1_content(common_chat_msg_parser & builder) { + static const common_regex function_regex("(?:<|tool▁call▁begin|>)?([^\\n<]+)(?:<|tool▁sep|>)"); + + static const common_regex close_regex("(?:[\\s]*)?<|tool▁call▁end|>"); + static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); + static const common_regex tool_calls_end("<|tool▁calls▁end|>"); + + if (!builder.syntax().parse_tool_calls) { + LOG_DBG("%s: not parse_tool_calls\n", __func__); + builder.add_content(builder.consume_rest()); + return; + } + + LOG_DBG("%s: parse_tool_calls\n", __func__); + + parse_json_tool_calls( + builder, + /* block_open= */ tool_calls_begin, + /* function_regex_start_only= */ std::nullopt, + function_regex, + close_regex, + tool_calls_end); +} + +static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) { + // DeepSeek V3.1 outputs reasoning content between "" and "" tags, followed by regular content + // First try to parse using the standard reasoning parsing method + LOG_DBG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str()); + + auto start_pos = builder.pos(); + auto found_end_think = builder.try_find_literal(""); + builder.move_to(start_pos); + + if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) { + LOG_DBG("%s: no end_think, not partial, adding content\n", __func__); + common_chat_parse_deepseek_v3_1_content(builder); + } else if (builder.try_parse_reasoning("", "")) { + // If reasoning was parsed successfully, the remaining content is regular content + LOG_DBG("%s: parsed reasoning, adding content\n", __func__); + // <|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>NAME\n```json\nJSON\n```<|tool▁call▁end|><|tool▁calls▁end|> + common_chat_parse_deepseek_v3_1_content(builder); + } else { + if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) { + LOG_DBG("%s: reasoning_format none, adding content\n", __func__); + common_chat_parse_deepseek_v3_1_content(builder); + return; + } + // If no reasoning tags found, check if we should treat everything as reasoning + if (builder.syntax().thinking_forced_open) { + // If thinking is forced open but no tags found, treat everything as reasoning + LOG_DBG("%s: thinking_forced_open, adding reasoning content\n", __func__); + builder.add_reasoning_content(builder.consume_rest()); + } else { + LOG_DBG("%s: no thinking_forced_open, adding content\n", __func__); + // <|tool▁call▁begin|>NAME<|tool▁sep|>JSON<|tool▁call▁end|> + common_chat_parse_deepseek_v3_1_content(builder); + } + } +} + static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; auto prompt = apply(tmpl, inputs); @@ -2365,6 +2495,12 @@ static common_chat_params common_chat_templates_apply_jinja( } } + // DeepSeek V3.1: detect based on specific patterns in the template + if (src.find("message['prefix'] is defined and message['prefix'] and thinking") != std::string::npos && + params.json_schema.is_null()) { + return common_chat_params_init_deepseek_v3_1(tmpl, params); + } + // DeepSeek R1: use handler in all cases except json schema (thinking / tools). if (src.find("<|tool▁calls▁begin|>") != std::string::npos && params.json_schema.is_null()) { return common_chat_params_init_deepseek_r1(tmpl, params); @@ -2537,6 +2673,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) { case COMMON_CHAT_FORMAT_DEEPSEEK_R1: common_chat_parse_deepseek_r1(builder); break; + case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1: + common_chat_parse_deepseek_v3_1(builder); + break; case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: common_chat_parse_functionary_v3_2(builder); break; diff --git a/common/chat.h b/common/chat.h index 41851022d9..5170fc14f4 100644 --- a/common/chat.h +++ b/common/chat.h @@ -107,6 +107,7 @@ enum common_chat_format { COMMON_CHAT_FORMAT_FIREFUNCTION_V2, COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, + COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, COMMON_CHAT_FORMAT_HERMES_2_PRO, COMMON_CHAT_FORMAT_COMMAND_R7B, COMMON_CHAT_FORMAT_GRANITE, diff --git a/models/templates/README.md b/models/templates/README.md index 2e8eaa5953..3a649b8f4d 100644 --- a/models/templates/README.md +++ b/models/templates/README.md @@ -22,4 +22,5 @@ These templates can be updated with the following commands: ./scripts/get_chat_template.py Qwen/QwQ-32B > models/templates/Qwen-QwQ-32B.jinja ./scripts/get_chat_template.py Qwen/Qwen3-0.6B > models/templates/Qwen-Qwen3-0.6B.jinja ./scripts/get_chat_template.py zai-org/GLM-4.5 > models/templates/zai-org-GLM-4.5.jinja +./scripts/get_chat_template.py deepseek-ai/DeepSeek-V3.1 > models/templates/deepseek-ai-DeepSeek-V3.1.jinja ``` diff --git a/models/templates/deepseek-ai-DeepSeek-V3.1.jinja b/models/templates/deepseek-ai-DeepSeek-V3.1.jinja new file mode 100644 index 0000000000..e5656196a3 --- /dev/null +++ b/models/templates/deepseek-ai-DeepSeek-V3.1.jinja @@ -0,0 +1,3 @@ +{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% if not thinking is defined %}{% set thinking = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, system_prompt='', is_first_sp=true, is_last_user=false) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + ' + +' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{ bos_token }}{{ ns.system_prompt }}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{%- set ns.is_first = false -%}{%- set ns.is_last_user = true -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %}{%- if ns.is_last_user %}{{'<|Assistant|>'}}{%- endif %}{%- set ns.is_last_user = false -%}{%- set ns.is_first = false %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls'] %}{%- if not ns.is_first %}{%- if message['content'] is none %}{{'<|tool▁calls▁begin|><|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] + '<|tool▁call▁end|>'}}{%- else %}{{message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] + '<|tool▁call▁end|>'}}{%- endif %}{%- set ns.is_first = true -%}{%- else %}{{'<|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] + '<|tool▁call▁end|>'}}{%- endif %}{%- endfor %}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none) %}{%- if ns.is_last_user %}{{'<|Assistant|>'}}{%- if message['prefix'] is defined and message['prefix'] and thinking %}{{''}} {%- else %}{{''}}{%- endif %}{%- endif %}{%- set ns.is_last_user = false -%}{%- if ns.is_tool %}{{message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{%- set content = message['content'] -%}{%- if '' in content %}{%- set content = content.split('', 1)[1] -%}{%- endif %}{{content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_last_user = false -%}{%- set ns.is_tool = true -%}{{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endfor -%}{%- if add_generation_prompt and ns.is_last_user and not ns.is_tool %}{{'<|Assistant|>'}}{%- if not thinking %}{{''}}{%- else %}{{''}}{%- endif %}{% endif %} \ No newline at end of file diff --git a/tests/test-chat-parser.cpp b/tests/test-chat-parser.cpp index 59e44e07d2..547ebb4871 100644 --- a/tests/test-chat-parser.cpp +++ b/tests/test-chat-parser.cpp @@ -15,14 +15,20 @@ #include "regex-partial.h" template -static void assert_equals(const T & expected, const T & actual) { +static void assert_equals(const std::string_view label, const T & expected, const T & actual) { if (expected != actual) { + std::cerr << label << std::endl; std::cerr << "Expected: " << expected << std::endl; std::cerr << "Actual: " << actual << std::endl; std::cerr << std::flush; throw std::runtime_error("Test failed"); } } + +template +static void assert_equals(const T & expected, const T & actual) { + assert_equals("", expected, actual); +} static void assert_equals(const char * expected, const std::string & actual) { return assert_equals(expected, actual); } @@ -46,6 +52,7 @@ static void assert_throws(const std::function & fn, const std::string & } static void test_reasoning() { + //common_log_set_verbosity_thold(LOG_DEFAULT_DEBUG); { common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, { /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY, @@ -99,6 +106,36 @@ static void test_reasoning() { assert_equals("Cogito", builder.result().content); assert_equals("Ergo sum", builder.consume_rest()); } + // Test DeepSeek V3.1 parsing - reasoning content followed by "" and then regular content + { + common_chat_syntax syntax = { + /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ true, + /* .parse_tool_calls = */ true, + }; + const std::string variant("deepseek_v3_1_reasoning_format_deepseek"); + common_chat_msg_parser builder("REASONINGok", /* is_partial= */ false, syntax); + assert_equals(variant, true, builder.try_parse_reasoning("", "")); + assert_equals(variant, std::string("REASONING"), builder.result().reasoning_content); + assert_equals(variant, std::string("ok"), builder.consume_rest()); + } + // Test DeepSeek V3.1 parsing - reasoning_format none - reasoning content followed by "" and then regular content + { + common_chat_syntax syntax = { + /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ true, + /* .parse_tool_calls = */ true, + }; + const std::string variant("deepseek_v3_1_reasoning_format_none"); + const std::string input = "REASONINGok"; + auto msg = common_chat_parse(input, false, syntax); + assert_equals(variant, std::string("REASONINGok"), msg.content); + assert_equals(variant, std::string(""), msg.reasoning_content); + } } static void test_regex() { @@ -186,6 +223,159 @@ static void test(const std::string & input, bool is_partial, const std::vectoris_partial); assert_equals(expected, args_paths.size() == 1 && args_paths[0].empty() ? js->value.get() : js->value.dump()); } + +static void test_deepseek_v3_1_tool_calls() { + //common_log_set_verbosity_thold(LOG_DEFAULT_DEBUG); + // variant: happy path for when it works as the model card says it should + const std::string variant("simple"); + common_chat_syntax syntax = { + /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ false, + /* .parse_tool_calls = */ true, + }; + const std::string input = "<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>"; + auto msg = common_chat_parse(input, false, syntax); + assert_equals(variant, 1, msg.tool_calls.size()); + assert_equals(variant, std::string("get_time"), msg.tool_calls[0].name); + // JSON arguments are dumped without spaces + assert_equals(variant, std::string("{\"city\":\"Tokyo\"}"), msg.tool_calls[0].arguments); + assert_equals(variant, std::string(""), msg.content); + assert_equals(variant, std::string(""), msg.reasoning_content); + + // variant: simple + thinking open + { + common_chat_syntax syntax = { + /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ true, + /* .parse_tool_calls = */ true, + }; + const std::string variant("simple_thinking"); + const std::string in = "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>"; + auto m = common_chat_parse(in, false, syntax); + assert_equals(variant, 1, m.tool_calls.size()); + assert_equals(variant, std::string("get_time"), m.tool_calls[0].name); + assert_equals(variant, std::string("{\"city\":\"Tokyo\"}"), m.tool_calls[0].arguments); + assert_equals(variant, std::string(""), m.content); + assert_equals(variant, std::string("REASONING"), m.reasoning_content); + } + // variant: simple + multiple tool calls + { + common_chat_syntax syntax = { + /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ false, + /* .parse_tool_calls = */ true, + }; + const std::string variant("simple_multiple_tool_calls"); + const std::string in = "CONTENT<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Paris\"}<|tool▁call▁end|><|tool▁call▁begin|>get_weather<|tool▁sep|>{\"city\": \"Paris\"}<|tool▁call▁end|><|tool▁calls▁end|>"; + auto m = common_chat_parse(in, false, syntax); + assert_equals(variant, 2, m.tool_calls.size()); + assert_equals(variant, std::string("get_time"), m.tool_calls[0].name); + assert_equals(variant, std::string("{\"city\":\"Paris\"}"), m.tool_calls[0].arguments); + assert_equals(variant, std::string("get_weather"), m.tool_calls[1].name); + assert_equals(variant, std::string("{\"city\":\"Paris\"}"), m.tool_calls[1].arguments); + assert_equals(variant, std::string("CONTENT"), m.content); + assert_equals(variant, std::string(""), m.reasoning_content); + } + + + // variant: thinking forced open + tool call in reasoning content + { + common_chat_syntax syntax = { + /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ true, + /* .parse_tool_calls = */ true, + }; + const std::string variant("thinking_forced_open_tool_call_in_reasoning"); + const std::string in = "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time2<|tool▁sep|>{\"city\": \"Tokyo2\"}<|tool▁call▁end|><|tool▁calls▁end|>REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>"; + auto m = common_chat_parse(in, false, syntax); + assert_equals(variant, 1, m.tool_calls.size()); + assert_equals(variant, std::string("get_time"), m.tool_calls[0].name); + assert_equals(variant, std::string("{\"city\":\"Tokyo\"}"), m.tool_calls[0].arguments); + assert_equals(variant, std::string(""), m.content); + assert_equals(variant, std::string("REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time2<|tool▁sep|>{\"city\": \"Tokyo2\"}<|tool▁call▁end|><|tool▁calls▁end|>REASONING"), m.reasoning_content); + } + + // variant: thinking forced open + tool call in reasoning content + no closing think + not partial + // This is a bit of a fine tuning issue on the model's part IMO. It really should not be attempting + // to make tool calls in reasoning content according to the model card, but it does sometimes, so + // add the reasoning content as regular content and parse the tool calls. + { + common_chat_syntax syntax = { + /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ true, + /* .parse_tool_calls = */ true, + }; + const std::string variant("thinking_forced_open_tool_call_in_reasoning_no_closing_think_not_partial"); + const std::string in = "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>"; + auto m = common_chat_parse(in, false, syntax); + assert_equals(variant, std::string("REASONING"), m.content); + assert_equals(variant, std::string(""), m.reasoning_content); + assert_equals(variant, 1, m.tool_calls.size()); + assert_equals(variant, std::string("get_time"), m.tool_calls[0].name); + assert_equals(variant, std::string("{\"city\":\"Tokyo\"}"), m.tool_calls[0].arguments); + } + + // variant: thinking forced open + tool call in reasoning content + no closing think + partial + { + common_chat_syntax syntax = { + /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ true, + /* .parse_tool_calls = */ true, + }; + const std::string variant("thinking_forced_open_tool_call_in_reasoning_no_closing_think_partial"); + const std::string in = "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>"; + auto m = common_chat_parse(in, /* is_partial= */ true, syntax); + assert_equals(variant, std::string("REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>"), m.reasoning_content); + assert_equals(variant, std::string(""), m.content); + assert_equals(variant, 0, m.tool_calls.size()); + } + + // variant: thinking not forced open + reasoning + regular content + no tool calls + { + common_chat_syntax syntax = { + /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ true, + /* .parse_tool_calls = */ true, + }; + const std::string variant("thinking_forced_open_reasoning_regular_content_no_tool_calls"); + const std::string in = "REASONINGCONTENT"; + auto m = common_chat_parse(in, false, syntax); + assert_equals(variant, 0, m.tool_calls.size()); + assert_equals(variant, std::string("CONTENT"), m.content); + assert_equals(variant, std::string("REASONING"), m.reasoning_content); + } + // variant: thinking not forced open + missing reasoning + no tool calls + { + common_chat_syntax syntax = { + /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ false, + /* .parse_tool_calls = */ true, + }; + const std::string variant("thinking_not_forced_open_missing_reasoning_no_tool_calls"); + const std::string in = "CONTENT"; + auto m = common_chat_parse(in, false, syntax); + assert_equals(variant, 0, m.tool_calls.size()); + assert_equals(variant, std::string("CONTENT"), m.content); + assert_equals(variant, std::string(""), m.reasoning_content); + } +} + static void test_with_args(const std::string & input, const std::string & expected, bool parse_as_partial = true, bool is_partial = true) { common_chat_msg_parser builder(input, parse_as_partial, {}); auto js = builder.try_consume_json_with_dumped_args({{"args"}}, {}); @@ -347,6 +537,7 @@ int main() { test_json_with_dumped_args(); test_reasoning(); test_regex(); + test_deepseek_v3_1_tool_calls(); std::cout << "All tests passed!\n"; return 0; } diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 17ff7ea9c2..ac8a0ade1f 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -1757,7 +1757,6 @@ static void test_template_output_parsers() { /* is_partial= */ false, {COMMON_CHAT_FORMAT_SEED_OSS})); } - { auto tmpls = read_templates("models/templates/NVIDIA-Nemotron-Nano-v2.jinja"); std::vector end_tokens{ "" }; @@ -1828,6 +1827,142 @@ static void test_template_output_parsers() { /* expect_grammar_triggered= */ true ); } + { + auto tmpls = read_templates("models/templates/deepseek-ai-DeepSeek-V3.1.jinja"); + std::vector end_tokens{ "<|end▁of▁sentence|>" }; + + for (const auto & inputs : { inputs_no_tools, inputs_tools }) { + auto params = common_chat_templates_apply(tmpls.get(), inputs); + assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, params.format); + assert_equals(true, params.thinking_forced_open); + } + + test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + assert_msg_equals( + simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking"), + common_chat_parse( + "I'm\nthinkingHello, world!\nWhat's up?", + /* is_partial= */ false, + { + COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ true, + })); + // variant: thinking forced open, reasoning_format none + assert_msg_equals( + simple_assist_msg("REASONINGok", ""), + common_chat_parse( + "REASONINGok", + /* is_partial= */ false, + { + COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ true, + /* .parse_tool_calls = */ true, + })); + // variant: happy path for when it works as the model card says it should + assert_msg_equals( + simple_assist_msg("", "", "get_time", "{\"city\":\"Tokyo\"}"), + common_chat_parse( + "<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>", + /* is_partial= */ false, + { + COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ false, + /* .parse_tool_calls = */ true, + })); + // variant: simple + thinking open + assert_msg_equals( + simple_assist_msg("", "REASONING", "get_time", "{\"city\":\"Tokyo\"}"), + common_chat_parse( + "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>", + /* is_partial= */ false, + { + COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ true, + /* .parse_tool_calls = */ true, + })); + // variant: simple + multiple tool calls + common_chat_msg message_assist_multiple_calls; + message_assist_multiple_calls.role = "assistant"; + message_assist_multiple_calls.content = "CONTENT"; + message_assist_multiple_calls.tool_calls.push_back({"get_time", "{\"city\":\"Paris\"}", ""}); + message_assist_multiple_calls.tool_calls.push_back({"get_weather", "{\"city\":\"Paris\"}", ""}); + assert_msg_equals( + message_assist_multiple_calls, + common_chat_parse( + "CONTENT<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Paris\"}<|tool▁call▁end|><|tool▁call▁begin|>get_weather<|tool▁sep|>{\"city\": \"Paris\"}<|tool▁call▁end|><|tool▁calls▁end|>", + /* is_partial= */ false, + { + COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ false, + /* .parse_tool_calls = */ true, + })); + // variant: thinking forced open + tool call in reasoning content + assert_msg_equals( + simple_assist_msg("", "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time2<|tool▁sep|>{\"city\": \"Tokyo2\"}<|tool▁call▁end|><|tool▁calls▁end|>REASONING", "get_time", "{\"city\":\"Tokyo\"}"), + common_chat_parse( + "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time2<|tool▁sep|>{\"city\": \"Tokyo2\"}<|tool▁call▁end|><|tool▁calls▁end|>REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>", + /* is_partial= */ false, + { + COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ true, + /* .parse_tool_calls = */ true, + })); + // variant: thinking forced open + tool call in reasoning content + no closing think + not partial + // This is a bit of a fine tuning issue on the model's part IMO. It really should not be attempting + // to make tool calls in reasoning content according to the model card, but it does sometimes, so + // add the reasoning content as regular content and parse the tool calls. + assert_msg_equals( + simple_assist_msg("REASONING", "", "get_time", "{\"city\":\"Tokyo\"}"), + common_chat_parse( + "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>", + /* is_partial= */ false, + { + COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ true, + /* .parse_tool_calls = */ true, + })); + // variant: thinking forced open + tool call in reasoning content + no closing think + partial + assert_msg_equals( + simple_assist_msg("", "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>", "", ""), + common_chat_parse( + "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>", + /* is_partial= */ true, + { + COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ true, + /* .parse_tool_calls = */ true, + })); + // variant: thinking not forced open + missing reasoning + no tool calls + assert_msg_equals( + simple_assist_msg("CONTENT", ""), + common_chat_parse( + "CONTENT", + /* is_partial= */ false, + { + COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ false, + /* .parse_tool_calls = */ true, + })); + } } static void test_msg_diffs_compute() { From 0a16bf52e6874369cb4fd2bdc4863ef984b688e7 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Tue, 9 Sep 2025 01:23:46 +0800 Subject: [PATCH 21/46] CUDA: generate_cu_files.py - add missing mxfp4 (#15880) --- ggml/src/ggml-cuda/template-instances/generate_cu_files.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index 3428113dc8..a92a9e52cf 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -24,7 +24,7 @@ TYPES_MMQ = [ "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K", "GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S", - "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS" + "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4" ] SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. From e68aa10d8f3d26fdad5b912540362d79de5460e3 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Mon, 8 Sep 2025 13:10:07 -0500 Subject: [PATCH 22/46] vulkan: sort graph to allow more parallel execution (#15850) * vulkan: sort graph to allow more parallel execution Add a backend proc to allow the backend to modify the graph. The vulkan implementation looks at which nodes depend on each other and greedily reorders them to group together nodes that don't depend on each other. It only reorders the nodes, doesn't change the contents of any of them. With #15489, this reduces the number of synchronizations needed. * call optimize_graph per-split --- ggml/src/ggml-backend-impl.h | 3 + ggml/src/ggml-backend.cpp | 11 +++ ggml/src/ggml-blas/ggml-blas.cpp | 1 + ggml/src/ggml-cann/ggml-cann.cpp | 1 + ggml/src/ggml-cpu/ggml-cpu.cpp | 1 + ggml/src/ggml-cuda/ggml-cuda.cu | 1 + ggml/src/ggml-metal/ggml-metal.m | 1 + ggml/src/ggml-opencl/ggml-opencl.cpp | 1 + ggml/src/ggml-rpc/ggml-rpc.cpp | 1 + ggml/src/ggml-sycl/ggml-sycl.cpp | 1 + ggml/src/ggml-vulkan/ggml-vulkan.cpp | 130 +++++++++++++++++++++++++++ ggml/src/ggml-webgpu/ggml-webgpu.cpp | 1 + ggml/src/ggml-zdnn/ggml-zdnn.cpp | 1 + 13 files changed, 154 insertions(+) diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h index c36c12d657..2db5c4e0f8 100644 --- a/ggml/src/ggml-backend-impl.h +++ b/ggml/src/ggml-backend-impl.h @@ -114,6 +114,9 @@ extern "C" { void (*event_record)(ggml_backend_t backend, ggml_backend_event_t event); // wait for an event on on a different stream void (*event_wait) (ggml_backend_t backend, ggml_backend_event_t event); + + // (optional) sort/optimize the nodes in the graph + void (*optimize_graph) (ggml_backend_t backend, struct ggml_cgraph * cgraph); }; struct ggml_backend { diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index f615ab4bee..7646f3f134 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -463,6 +463,13 @@ void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event) backend->iface.event_wait(backend, event); } +static void ggml_backend_optimize_graph(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + GGML_ASSERT(backend); + if (backend->iface.optimize_graph != NULL) { + backend->iface.optimize_graph(backend, cgraph); + } +} + // Backend device const char * ggml_backend_dev_name(ggml_backend_dev_t device) { @@ -1298,6 +1305,10 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra struct ggml_backend_sched_split * split = &sched->splits[i]; split->graph = ggml_graph_view(graph, split->i_start, split->i_end); + // Optimize this split of the graph. This needs to happen before we make graph_copy, + // so they are in sync. + ggml_backend_optimize_graph(sched->backends[split->backend_id], &split->graph); + // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split for (int j = 0; j < split->n_inputs; j++) { assert(graph_copy->size > (graph_copy->n_nodes + 1)); diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp index aeac2e5744..cdfc5a9bc2 100644 --- a/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ggml/src/ggml-blas/ggml-blas.cpp @@ -270,6 +270,7 @@ static struct ggml_backend_i blas_backend_i = { /* .graph_compute = */ ggml_backend_blas_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .optimize_graph = */ NULL, }; static ggml_guid_t ggml_backend_blas_guid(void) { diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 2f9f373f54..2e47ad90af 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2690,6 +2690,7 @@ static const ggml_backend_i ggml_backend_cann_interface = { /* .graph_compute = */ ggml_backend_cann_graph_compute, /* .event_record = */ ggml_backend_cann_event_record, /* .event_wait = */ ggml_backend_cann_event_wait, + /* .optimize_graph = */ NULL, }; /** diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 0c15165f1b..2b81f8b9af 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -190,6 +190,7 @@ static const struct ggml_backend_i ggml_backend_cpu_i = { /* .graph_compute = */ ggml_backend_cpu_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .optimize_graph = */ NULL, }; static ggml_guid_t ggml_backend_cpu_guid(void) { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 0c6bd36396..efca2c775a 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3135,6 +3135,7 @@ static const ggml_backend_i ggml_backend_cuda_interface = { /* .graph_compute = */ ggml_backend_cuda_graph_compute, /* .event_record = */ ggml_backend_cuda_event_record, /* .event_wait = */ ggml_backend_cuda_event_wait, + /* .optimize_graph = */ NULL, }; static ggml_guid_t ggml_backend_cuda_guid() { diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index eeb6c9d4b7..e76fb71263 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -6275,6 +6275,7 @@ static struct ggml_backend_i ggml_backend_metal_i = { /* .graph_compute = */ ggml_backend_metal_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .optimize_graph = */ NULL, }; static ggml_guid_t ggml_backend_metal_guid(void) { diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 727163b7fd..b188c5af34 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -2838,6 +2838,7 @@ static ggml_backend_i ggml_backend_opencl_i = { /* .graph_compute = */ ggml_backend_opencl_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .optimize_graph = */ NULL, }; ggml_backend_t ggml_backend_opencl_init(void) { diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index e84ff93efc..d4833068d0 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -795,6 +795,7 @@ static ggml_backend_i ggml_backend_rpc_interface = { /* .graph_compute = */ ggml_backend_rpc_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .optimize_graph = */ NULL, }; ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) { diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 877fbf7e86..619ccaefc0 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4063,6 +4063,7 @@ static ggml_backend_i ggml_backend_sycl_interface = { /* .graph_compute = */ ggml_backend_sycl_graph_compute, /* .event_record = */ ggml_backend_sycl_event_record, /* .event_wait = */ ggml_backend_sycl_event_wait, + /* .optimize_graph = */ NULL, }; static ggml_guid_t ggml_backend_sycl_guid() { diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index e6245d0cfc..f0fa9e668c 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -583,6 +583,7 @@ struct vk_device_struct { bool disable_fusion; bool disable_host_visible_vidmem; bool allow_sysmem_fallback; + bool disable_optimize_graph; #ifdef GGML_VULKAN_MEMORY_DEBUG std::unique_ptr memory_logger; @@ -3592,6 +3593,9 @@ static vk_device ggml_vk_get_device(size_t idx) { const char* GGML_VK_ALLOW_SYSMEM_FALLBACK = getenv("GGML_VK_ALLOW_SYSMEM_FALLBACK"); device->allow_sysmem_fallback = GGML_VK_ALLOW_SYSMEM_FALLBACK != nullptr; + const char* GGML_VK_DISABLE_OPTIMIZE_GRAPH = getenv("GGML_VK_DISABLE_OPTIMIZE_GRAPH"); + device->disable_optimize_graph = GGML_VK_DISABLE_OPTIMIZE_GRAPH != nullptr; + bool fp16_storage = false; bool fp16_compute = false; bool maintenance4_support = false; @@ -11853,6 +11857,131 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg UNUSED(backend); } +// Sort the graph for improved parallelism. +static void ggml_vk_optimize_graph(ggml_backend_t backend, struct ggml_cgraph * graph) +{ + VK_LOG_DEBUG("ggml_vk_optimize_graph(" << graph->n_nodes << " nodes)"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + + if (ctx->device->disable_optimize_graph) { + return; + } + + auto const &is_empty = [](ggml_tensor * node) -> bool { + return node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE; + }; + + auto const &is_src_of = [](const ggml_tensor *dst, const ggml_tensor *src) -> bool { + for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) { + if (dst->src[s] == src) { + return true; + } + } + // implicit dependency if they view the same tensor + const ggml_tensor *dst2 = dst->view_src ? dst->view_src : dst; + const ggml_tensor *src2 = src->view_src ? src->view_src : src; + if (dst2 == src2) { + return true; + } + return false; + }; + + // This function tries to reorder the graph to allow nodes to run in parallel. + // This helps with small batches, but for large batches its a slowdown, probably + // due to cache contention. So only reorder if the majority of nodes have few rows. + int num_small_nodes = 0; + int num_counted_nodes = 0; + for (int i = 0; i < graph->n_nodes; ++i) { + if (!is_empty(graph->nodes[i]) && + graph->nodes[i]->op != GGML_OP_SET_ROWS) { + if (ggml_nrows(graph->nodes[i]) <= 8) { + num_small_nodes++; + } + num_counted_nodes++; + } + } + if (num_small_nodes < num_counted_nodes / 2) { + return; + } + + std::vector new_order; + std::vector used(graph->n_nodes, false); + int first_unused = 0; + while (first_unused < graph->n_nodes) { + std::vector current_set; + + // First, grab the next unused node. + current_set.push_back(first_unused); + + // Loop through the next N nodes. Grab any that don't depend on other nodes that + // haven't already been run. Nodes that have already been run have used[i] set + // to true. Allow nodes that depend on the previous node if it's a fusion pattern + // that we support (e.g. RMS_NORM + MUL). + // This first pass only grabs "real" (non-view nodes). Second pass grabs view nodes. + // The goal is to not interleave real and view nodes in a way that breaks fusion. + const int NUM_TO_CHECK = 20; + for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) { + if (used[j]) { + continue; + } + if (is_empty(graph->nodes[j])) { + continue; + } + bool ok = true; + for (int c = first_unused; c < j; ++c) { + if (!used[c] && + is_src_of(graph->nodes[j], graph->nodes[c]) && + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL)) { + ok = false; + break; + } + } + if (ok) { + current_set.push_back(j); + } + } + // Second pass grabs view nodes. + // Skip this if it would break a fusion optimization (don't split up add->rms_norm or add->add). + if (graph->nodes[current_set.back()]->op != GGML_OP_ADD) { + for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) { + if (used[j]) { + continue; + } + if (!is_empty(graph->nodes[j])) { + continue; + } + bool ok = true; + for (int c = first_unused; c < j; ++c) { + bool c_in_current_set = std::find(current_set.begin(), current_set.end(), c) != current_set.end(); + // skip views whose srcs haven't been processed. + if (!used[c] && + is_src_of(graph->nodes[j], graph->nodes[c]) && + !c_in_current_set) { + ok = false; + break; + } + } + if (ok) { + current_set.push_back(j); + } + } + } + + // Push the current set into new_order + for (auto c : current_set) { + new_order.push_back(graph->nodes[c]); + used[c] = true; + } + while (first_unused < graph->n_nodes && used[first_unused]) { + first_unused++; + } + } + // Replace the graph with the new order. + for (int i = 0; i < graph->n_nodes; ++i) { + graph->nodes[i] = new_order[i]; + } +} + // TODO: enable async and synchronize static ggml_backend_i ggml_backend_vk_interface = { /* .get_name = */ ggml_backend_vk_name, @@ -11868,6 +11997,7 @@ static ggml_backend_i ggml_backend_vk_interface = { /* .graph_compute = */ ggml_backend_vk_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .optimize_graph = */ ggml_vk_optimize_graph, }; static ggml_guid_t ggml_backend_vk_guid() { diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index aad27bf605..df6a3ed95a 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -665,6 +665,7 @@ static ggml_backend_i ggml_backend_webgpu_i = { /* .graph_compute = */ ggml_backend_webgpu_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .optimize_graph = */ NULL, }; /* End GGML Backend Interface */ diff --git a/ggml/src/ggml-zdnn/ggml-zdnn.cpp b/ggml/src/ggml-zdnn/ggml-zdnn.cpp index 7507a52aea..a4c51ab4e7 100644 --- a/ggml/src/ggml-zdnn/ggml-zdnn.cpp +++ b/ggml/src/ggml-zdnn/ggml-zdnn.cpp @@ -586,6 +586,7 @@ static ggml_backend_i ggml_backend_zdnn_i = { /* .graph_compute = */ ggml_backend_zdnn_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .optimize_graph = */ NULL, }; static ggml_guid_t ggml_backend_zdnn_guid(void) { From fe1c92cd7bb491e9c5767c9de413f2b7dc1e832b Mon Sep 17 00:00:00 2001 From: j-k Date: Mon, 8 Sep 2025 19:57:01 +0100 Subject: [PATCH 23/46] media : add llama1 icon (#15878) Add svg and png based off llama1-icon.svg --- media/llama1-icon.png | Bin 0 -> 16045 bytes media/llama1-icon.svg | 87 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+) create mode 100644 media/llama1-icon.png create mode 100644 media/llama1-icon.svg diff --git a/media/llama1-icon.png b/media/llama1-icon.png new file mode 100644 index 0000000000000000000000000000000000000000..0e44672e54bf3cb91a09d1a2469918e7f6333b3f GIT binary patch literal 16045 zcmc(`_dnJD|3CgfR(2AZNysLfjFOd*m6es1z0R?Zk(pWYM0WNj>u}5>viEjS!m*We z2;p=0djA98-@cbiF6GgE-tM>i{eHV$uh;wA=US>1*BP%v5JaJ__Cyzg@UT}uBt+oL z!1Qtq_>0tC&DaZqSfZ|e@DkU4+k%g`Vai4@JvTd;kCmq_Q5f&`)2NLdn_L!j|Tgpa7{$Z|3L zRVXJk4q4uWsz4aRbPR3bgb@aYI{Yvt?)sY~>@MR0cI#Tr=SzoY98Bcz9^kUQ+Ba~4 z4+HOz+Q<*K9k*^Q$Xg6}FjMH8nwqvg|2$mxD^ozD_h~w(G(5>bbVcDB1hqy^a2q?Q zmvajIXRW~SneW&K2Kz;H%56kIiY?VE*5Wfg;^u7A&^>>*O$b2*Qv${ce5HSIEKCwG z=1n$v=Uho7a^{;xp^!;}SC~zNYW*_^5{9(YYumFDd*wR-eh4?2i(plkVq(a|e5^Qc)4 zM3W#qowjWgDlF8l8|2G?9`4)rj?2(|N$Ub>|tJ{R~lloNn4rc_G(r@SEH z8t8o~)cRm;X>rNCIL0fA1gux9pq1%S)@sjCz}oCzzR>LZ$?Yp@E_;(;YD5nH{By@A z=fT#}3HhXzl`d8e;;Iu6bnxG+Z;kWFN{f!Ml!E9TZR&Lh`m5DlrZ3N~{=&)0sTK3O zrGUuC2_J%#M*oWr=bn#@Gfp40qO8rA^@~cjXDEetYfyTA$zmD>GcJsY{f6I3LyaK`=V@Rca!PpZ>Sg4W@zv3dq_mK$-~iev z`D3ffC#!+-v3}HHlrfQT@dtB@9162(c;*6M6vSsx|6Nn~p+~vi{gpl$*Ap=$PsKdu zyP5IxDyh@WvWg4b#~>X2$zSMz2%n|h-gsQZ`m2nO=d)tXxUIjRX_S@nFX58dWL$!R zx0422J{;QP82zDmeba8M4FPNce7MfuKVy&n?Dg+`)Tyo>0n8 z0uq?-{7zAgJmT@w4ej`5vwrp=ho9EhIuQ@wuBZ#}n6k+U7Wxdv@!148J-IN!>70t< zxwC~*YO99}Aby4?FDXctMr(UMFtoGafsSb~2XTb7cTcxA;puy!L#|<;OH-|#pheZ( zE_j2E#wE#tdRnvk+5h%zIBs4*o~q^0kI zJcHWS89lwvepC|D89fF)*gZ7(L8j-Zbb1m&k#4e4ZMGA2xPbG!>P7vUOqtC^TJ*)) z8~^yXXMtb-K_|8O#1ga^s5Mlvx5RYv#u(MG!*k8O!KR~#Civ6N{effB5=ML$4lWzh(_v${qc|q>#0kE} zi*hkIP}#6DO&_D4&MbPjE~YHjCSUQH)O0nev8Q`iY45}L6z1I_gL+w09s8<0_-5@i zEEVf%XzYY~nt20Kp$6mfTI~~;*)Azru#GEYYfUq-fF6IZE%4K%$@0_JcYDZNw%V%_ zS>G%|(+JboR8yVWIZGY+_-|`@dE9V1iG@?{b6ak&&&~SJQH=qe&y5B2OoxJHtR~H* zIjVwX>xOq9s2I3?8G9ppTvBYgr)qaGTT|tzjrlp9*27iS;auC+k+)zmj9JbZ(LZ-0 z2GIy0t!YPn-w{qNFg3J`CqJBzE-bg>!X|w581ioJz35J^?8wd;N2b_GDMREp zLKCdimH81HgZ|4Fq5k*U2vrpV$0~HJtrw>oRJpMi&7|5{Y9tG{tSHIxK-rU@UridO zGFnJoJKTi!4GpV)df(>AlW^S3?=8Z%=hHs*uhH(@tGK;vpF2+iBcJCuJ`ksV63 zQ4Kb}#N}Q3G@7Ph8*h7@cSue))ym61PjnG#GQ(fdxF!=>z-ha^9qWh|%;#ck zL~q3P`gVEo)K<6$4N^lo?@ZDoj79h+DEg~Bc1xcgw*=c&Go3RYuby686tw=%ui-ya z5$^i(dH$2xv+9{WXXRtpd!wPB?u?4OcFUld`6ytQBe}e

eAJdvhgFw|}}LOQ(2( z>tv$=_`(b2l+jRj_EC}h@gAm^^|?K}56WKso_@o+v^9cTl#ENnD2$XIV@*z+aP%za zHiehRypZcvaY@0Fv9zY($FEBI=~ORk3r2PyME=^EBSkn^RqiL;2ntqK2_TnNhnm2K z3CS9}F|}Zq)=_Sjv}cqx(LbEDU`zh=dm}c2kjG4vd()dItWk0ULoAdAA7(3&z-J-2 z(Rds`9F~j3V75naEoMErPn8=XzUHc=@o{+zq?=O={ryCVYc&dN}mQ^&VQrE?xQttWPhoHLU$paYat+ubk55 zC$r-&^kT2ST-(nPQb_asHm2}!X?x+rwUE@m&I{y=2V<+cjxhn3i0O5I_d132ba1&y<`@*$ zai4h&y-b=e95`Lq4c;|?TH|PBD7R_33vP;E?$Q0nn{Y9w+B@lLVzjzI;XdA8dX%WJ zn4dXfx`l84ICu5k{MKU>EInAsS~%#R?u+u4e*cDS9YJa8wf zu#bAEu_>?@*0@~E`}wrzxe4HURhsG&Vz%nI3xUHJO&=%Y#B*54uKf#9?0wWQ6@k6Q zroBnoW0?KB*K9+Ssq1ua4tM?JZ&Zcuw=wQCM`9Ji(1ftl?zdkUva0uZH#f@Lf{A2q zf5x8IZR<1rK6dgvtc}ZF>eH7MW+Z{NYMGj;xj9huVE*M~@96cNg?XXtYhLV9T*xQa z+Nh36YA}Vc(O`VDwl_Eta+R{rgBRb;STD2&XUrSylA4X}?X#ZB5vsD8F6KXAu?f*p z%3U>JTy3nv`uUPSDZPu>bl7Lpb1Em>j3ql?F5DMQ#J%>)(Uv)aao58uG*VuFdlTNCncJif_{@QFYsmKvvEg^H)|-}-Z`e-^Ke+vLoFEfYLb%vDMf zE9R-N?may-np6z5H|`ksb4wgvmf)z;<32rfJv`sr`@Y?1uypZ-RG1=^QhhvtpvJqM zPj9S}uZ`tF1C2(GT;FobyyQ_bu2eN(QILRy8K32;GQXJ&`d_}ILZHa%NlJc>G|$>L zf;^%Y`yJOA}DrVQ~TKAgw>=4 z-}TPur*ivK*(XFlH-zk`R)6cPDrCRh-o7}Ya^b|v6C47S6MQ^_S>ToWC%khQ4IFPTCw$;sOqk8j-i z>%;M~`6nhnHel)_#}J+*4F=E7_-uMa(`UvXSVJi7#>797i(L|*ftWbf<-KzaS7-v1 z@&5BSx@a;III+HT0&Jga|A+?34^H7zB*)@B%=q%+;r4n6JvAt345@4k93|JcEw{wm z2&UaQ++~f7;XrJ?XI&e3)<9scm`kibo3gZc%c~_$X9G8Ovdp@_Ja!V?mNqsaE1gH= zOBl-fyjUx)jNR^H;Bv)opXl`VhrINeosluWA`E8T_S`iUZv|Rc81r~ek4#sP*%y%U zFCA#{$qx2y4;Jb=0oV9ZZ8^)qO?DCHmiBE^W<29o*zWr?tp{p9GJTbBRX@@kp^#W=lJi$9r?65s1@1)7pH@+&&*Q zOde5TSz;#<9`DA*p`In(`(fg%m8pMmA{gzQqsm5zA+YdaPv6C2w78M^E;{HUxrwXj z-^pjbx2@(R!=o|_vAWS?)923>q6C!a*4bD+Ne!z;XhawEdBNQpg9lSp*nSMvrx+|n-Myo zzvi#V@A!M)uHw^uzmljHTlG}U~BfCa-x@4OMbF+hU+)H zqxLtMajzSy%F)nuGE=ikT}BsGWzo*cbreY)#3aaIf&Gdu_Nbzb#9Zu!dMJwb-fldpWo(} zbdxlFJv(>#yYHQekt*@Scxm+K`b18lv7h{7LMy$sudn=8?bHUW{Om|Tm82eH7)<~u$!=)I093}LgY zSQ=d@+3!Oh zq0Cr&NRJvrl%7@}Z{Eui^PyK&EI$ZlFOWpTCyY#L3wG@X$xka$sB5l-YmYz6KvLDm z4-8Y)BE>w+X)PP};*l?1D#=!GBrt*oA{$^J!$xiu_^#yZ_`s^H53rVN5b>Q+iYK{j zFL_^WvWTj>CCf`%a+3*WGQtZDX`E#)q%TOwQ}TiaorNJ+|^jhNa=H@4J)pvK4cgdTI-JuAN#Te$7rsBZabb zY0Laae1#Zjk+>8I#vuqC!%TLs#$RM|I7NvUN)%`m#qUUOO+z{SzL$$*a&nNRCGbUe z8y8D?PiMy`iD{aZWP7R(ATvu~8aShLH+9c>9Qi8htg~J+9R6eSeFm>@a7ha#=6-F}f`G#Zw>poMvfe9O1qMrUw^e4(1 z?)^PjMWg>kcgj<`QcdmYK9i?Y^>;SaL3U)0Z|l`6MgUN5_zP^UDw+!rreVd#KT|?Vk8+vv*tbT+ zvGzW6xI#{Zf4aCQ#Bd8V5@f1mF<~|?nY}%OW&wx^9>j6gO8j5z@ODwX>6z>y484;V zua`X!6vuGYsR)u!s#RC{bU(HBbiNL;C2->xBs>o0Rq*iPV~f^?+!494a=3#;D_-Yo z_E8ENGCwy8gqMY}r2JTMfW&7l`IjK@h3kMnrf?CBv3EDAjH_GstOj(qLJ&6-R!fQx zQa`?fSuL~Fo-6ZYRo~C@Nfy-C|K(`h$#4zD-YTXjWWaoWtRan9o0G?hwgssHIPR^w z>FS3@u-DSxU3YkVw?JLIv4_)Hyw_z)SzxNyIexlL_G4Z?%*x{L{$O&0sOi}BQjE7< z|NMI%r+{GPg|mVAj8*ux)&xb#a6^2At-X4Vf4%B}_XLAV(&NU$KJmtRrpEvvbTUDx zi86h4EGH`!*}I5Mi8Fet!ms&Ey^c{S>fZ*~mmnQ$sA{RW0$PrfkIu!-bVL8U*+~nG z^W>-0i}1*t@RemCMhBp&-}hO!k&4l#d3YCFX7Ci?Nxr*OfyF2tO0CM-o{>0 zK2MueVD>R4Wru&OG>cHKxw6Cn%~|z@P=z7XHQN2ZE?Wb8b&&m8V>x+2Ol2~8V=7Nw zCrh3ukBFeUq%y@r-Sp}_9&$Nj>^Q+-=)Pcq0;)1MQE*S>mE&Y%!34K5ZhF%r zOM$ar9~HHW5!%*fIy^N~HP`UIW$Z&t83rn%oO%-axhlbq*{VaF?#0CIvlY?mH3|Ya zr^TK#P3!p6?m6ALzl)Yn#W2y<6GGwVz|B=SMvvE-(Fe~>5?v1WZwBz&Y_{^~n91># zf-Y@7i=asW8iFf_`=J{oishs)M;P%qLP2K(`z+;?BwEcv%X*W15EX_B7FjHx%E~wxlLTF2w+VXhi6h_U=!6OC!C*6|t*!`>{8ZH7?gyy#7I2jx z*-HDvhOl!2IsaEJ!P5gDXf$S+vqo`;NsJkm#WLpm0O>v7AiFKFWO8s`q_v{PFwF2u zs&19ZzaO=}JRRWa_=(c5itt3Slsvs-fPg zDEwHmd)3X2bzG;S*yk25^sEKpu?z?1MFX%<^6>V1>(}1)d??(vilEj9`R1CRu)ubX$F*MM<;xmwV0UfL=d(Cxf4Dv5 zq!1a6)n!&aMQ)tohBnSrk-1jsa}3iP^plvul#+$QPD>DdZm)B}rUsjWY9ExFzOa${ zL$rK%ltPEOGpG@uJjHCgQ-BkQ4|=3DNnMBgy{hNGKjsl3t(w3KfKs%(v+jPzhf(bp zb~w;&#pf;`Fh4Ad6?^ORl#8uTtmqaIi^3S9-o^r-5iN`;cai>xEsuNOg=L4Q-AhTK z?HLy#Te*(jYqN>m#q5-l>GzWhspMvN=Y(?r{vageI;mMaePjI~9E9x_4+~#VSlb) zjFHuUW|UQgnSk@2b3_2BQT1XmU%~D7wAdNluE+Vn*Lj+@Iu`c0zW6MUM*qk{?*O*MLca1vs#a`^*>&Mp?FTl?AsNzmQ;)%AjaJn`Qh~lp}XXMP6|gIiIg&60u08An67t3}M*H z!rRCJi1IDBkV^L(QX0hmj(@zSa4)4O!imSQ;_Cu{Xcqvohgid!MD%$wIO@3>EBL=3 zlWvVy?0pANGN0Z;>~C6K*HrwtvW)CKJuArYCV?yH>YkM;YI9Pi&FR+0fK}sF5OPMb zw7Cm0Uk^=qqyg~z%N_qLxc(#bM`|yre#37h$t9ftwy^PR%j2MwjOR$uWP-A zIh%TJ@3Y(w=YpAFa?BNy#oq`W^4pV%a{JxgQ!dskaK7`ENVCdHQua;ZOAo9N zvtC#tKv9)lx&eGGQPLJ#Nk_s8iZ9b(-reK;ngaN)umpz$5Enm(1AgQ3=wZIsU#3t> zU~rsrgNz@(Re~kT&$#1geU^kcm}B%)u+(RjRISx%v1_zGSzR~zQ-~I2?YUrGpVg5A zop+LX6Ha3jYB^OS`rDblj7P*I(U^`A1U;!vGy#e8^*Ls_hnc~M$d+L%B7C_{tt^JA zj_eoP-nXSXtDY%uk8`?g@i{rUDKCiHy6jiKHZ016wdcXU=q%rnO#>){nlI1>8_zC_LaRl-_gy6SK7uxaQdk+GuHpS^TE;avTrOq6YDwiB`PnaVIM!+IG%mT zG48duo*H0r)6LmZN(lmc-?ty!1XE|e<_}G&2%P_IC7SA0J9OL}ls2iY%jHJW<0cdG z%(r;U^gUY2I@~i6PJJ2v11Pr%3&!$E6#*7mt~|hd!G1PH-+C5TBJoP)>+;n5iNluW zZRYhFlcL5yhYle|DcCNrZTOdDz+x~|{#4WOp@(0tzhPa>oQcii|s5VY`af=d}!N47PGd-yeb z*=$m1<`?!Xq4nt3L$O42TQG#Lq%>`X0!6VVtJ!U{_8^f}?wFOu3`K;@01!+>#CAK; z&E?M!N&hy8UoU)I%#qc@jt1Pjk1hC|HS8O^u}4Qu!x=a9CY1~c0@pvAquY6bFLXu& z^zV+dMu)#%z8148chiq6sGV5I23y`Bz4kN^MwBPwx?Q^|auu=|$M12Sl?^8eHkHS2 z>sZ|bD4M*N?o&2ti8b%r?MJ^7^4gJk7F4c+(fkZ3y(?HFGkiGb2HbXYV@m<%PcA zwEA>kzUNcV~LjLBET_*9VVfI%yq>%#ORhy z-C?IfM+|fHkQ6llUjU@W+S>Ylx{3gcX?!DpXjCn6H2S)muczAu@uxc(FzWI zirPgE98?5Itidco(lp{+GmE>B=OJeu7Cf%mj%umGkWK^GfI1g>~a5>8bD%UyY11cWy z=VJl;>Q3Lxi_F4)#vC=t2L(Q*|A*3RTZAV-cSjL>O!kvoWk*2A;@h9Tm$c-iovvd7 zeS4hp$ClNBgv-_R@UwL~2g;MkDkyiun+`sNKyMWc+25gACwYL6F6P%b{rQ!#*}bsd z9w1QW!G{~HtXmNNTgBs2jrpH44uRdf4<4=_v-q9kawd4k_=pr0KM$azKx&$pqo~?+ z&N{CD<{Ih_3EdMEEW0Pc4wI9T!=g?EiLbWapOMTsoEn3qYcOD3c6 z<3ds|h&psnc3)QJlRA%iV=s~(hCX+Y=)~*1WWODFkM8sT-5Bm@FeQ2SV3Xz>R+P{g z4MyJcyfwV1JM27Z7bI=%0TWxFgk+)Cr!^B_RFgpDK16yn?+1JaEK5Uas>8AcPGHk0GSPSPQ1yuA#Y_o9<@$vrV(p5&G&Gf8fCz7X!I zqjlzUoRS=`p^+4ENa)>BxiUe=%|DyKpWVN0WMiiG_`66T26B9KepUL%=MQeQ{x}KC zI7B(gZ4BXA)`0$M`8VZMiT6uP+{VV16{f-23IQzvia1L`q|5=%jvOdK2`Qk9UCsg3 z)%lzn!wRzvIG7&JK^shJ8`^?k2MUZ;JWkW^9#_`0c1_hW0F&32ALo+*#Ybr{tcOL@ zMk^9sAlcJ$LTLb|3u;%%4R|(Mv~U)a;A%dwpEt+&?(Vec~ifGqEUL3pa0)<|?APBw$q26n;;Lf!-Wch{rd*+~*^ zRAtg1rC=8*WxnEds*%{Fp8pIGs>Y0pF5< zA<|ybj<*)AECyWc7M_c}#|xnWb>9a)OkdH>F++)8WTmpSxBR^d6X z#|{`IbLd_zs-%#%htS&R45as$b&+GzP#h6H+BViC?>cco#vqH3y|fP{lpn6>GiNOD zCC$}`c{5=Oa9Ec0)EGh8)e=i)LsNiUQOMbX@E5YC!3d^LG3aXWG6;QfDaXu#syqx1 zKXvR`=;Ahy04a*@kiN2;r5HY1nT){GU~9y1L#m7j%8yg@dF!Ii9Z%Kc(aUWz*%WjsU z%#c0-s3sXF864j3?DMv61A5MR4MoN^6E4OBZHxrZe1XYA`mraaKOgnRk@lcBj-72# zZH#~fjSC`TxP3+FqHo4=cI2{;nz;|Ay<6anLk%Hp;lgP7t*g6tW<&!BssalD(li{) zND3zaQ%;+4z92#TPI6%hGvzvbjqUH8GyJ>N2E*UjMN z$r1`VPQ&&KcMFBq?32EqBC8+R<2BRdM0Bg9`~zD~II#M&NAC!+fD#d2D{cm3WILUv#@WEvT$=?$pR*RW1}+OGu! ziHnN4hB_MkJ3WC}(y9)O{ZI%W7(#kl?V_NSKD=p?mJtFarn-~o^!8XBsx%0%ur`BN z#^d1ybIYzO7-)sHjX+g9{jdLCe8pMdU{{mvO`pYp-6@ z)2ETupSZ~tQivhI6dS7d9nkn&X97tnpe+oNL<0>M3vO;qL!DRNh5a=|iPi6Za1mn! z9XtTzJdLpVw7h^D3Jua1 z;DAByWMF`D(uc*1_L4WQmJl@MZ97DTZPvsF|DRLTibA3C>~!O4Qf@quJAw8x@7zZb zg;7oiiF{C>5Yz|6#-B)^*bfEGLsttGs#txam$YBTzZ=kpciw?n&d<38lXO<99ck^$eN&f>a$qBL0NLB>`kxBe{G9q`aBF7=qY9-Z2n35AfzY`D!;g| zQQi_J`n1E<_3PE-shhUP359k99JCboC=||0(m)(sFL6f`zIdPug3;wE16*|m50Ooo zL-GO2s(_Hg;Nkj-nQ~j390OAiP~Kg(7F_)JY>w0rc~OLUeUfdNnJr9O&7{Jpw)r`B`RXF5~^*;$MI=M->DyjAoGT~!;|A)mf~j%p?? zyj!0@5Hj{Sx8>gJi9KstzYYa#Dda$7K=nt-(_LA`U? zWLxgM&jP71hVJDlxb>C)Y%){EWywhw@!^jf5kqAWiaujYnfbPOyEk#)^c=-`zAaEe zD%Z93Q{`%;eJA#kfZ`OpHUZ>|1W??g+?Jj8_%E{UjmvW%tV1^)sG$6KaD9TzwsN0z zPt!&+TerMV9KEg^cShPznw(?s6{M3ct+t1E=j|0v7)%7(b5d|GDXYbqgA4;sW* zOh3040tM)Aq$HosVfccO9_&i8g}-pTv=3xM4uMbHT;s2DV(}geg#@YX)PM#Fmc@j( zoobfy6V{DKhV{NqkLbn-@b3uQ#Er^rrL5gXY8+J1_9(;lZ8cHhRi0G5P*5aublC*MePRZJF!hR&D<>fy#uIDx!`>n9hB_ksY*!o^GIoVIu znO6{4yG|Jp`D;j@pN-4!4*i~HjX8>}0W$UNKbybHE`|M$F1;P$^a}6r0Ob_$5@=#O z*EG|=n4IPUSt8&QuF!fwm>|y8w(?Sz9tc$N*>%Imo1&ev7A}6jl|T38WO{rh_C>4! zX{@Vs%hK+Hz*2w&7XtzVZu>&^RC?VzH(CGyK}7&lm$CPJMnhGx^RJzwno2IsyM_E} zzp&NtZ&Ia>PXtOmrQ58+NuU2JtjJz40QC6GJ&u%N$G4k5r%Q%EGV8fanrixk^o8ZP z4qcH?prn|4)@`h+3|}Cma80!eH^|~dx670-R}2vUY%)a*PB(wZN0o>d)tTgF@7Mbb zU+zr}m<<*|VamB@-?fEak1{3ZDX*)6t)q&h!!F>Y(RHP{^^cr(uf$Y!oMmLD)-o=y z^sT6#6;n+r0E+n1ME%VADg|P5s3hAA9mu%J^y9E&7Abo%XT~YlBwbA`Nnc;q@~hlo zxu>a;aFs(QpR-KVB+s$(aG$GR!TI}U7!|h3jA%vn8fql)PdMZG?r(r^3Y!~N<#D#e z_&WyZjO_6`0y1|lNm8D-krSw6VP9vX0j)LlcD2i~nu^J**w=n|>@%2sm`d*1vxm-` zpOE4Xp7Ft|mO=VeQYNW_PT!W2Uikc1-(FmS+f0)2@8^vh_pEkN|Dekb^ed&6f(3p< zRznWgoxH7$`e>Y+x0E+7P$A0F*oXGj-9qndj*2`+<<;QE=}2O&CR3`p_vP`XD~XD8d) zL{Dqb!WZ}l;fb&6FqLaWVT#Aim&d-<-u55;qx^1Z<-WZk49NHq@C4y>GG8upF2V!I z!HBLfmh+73r%TtXGQ*95#KVs5SK6JBFhJn^{?Gml3y8v{yF{V_ncelW|6@6UhgP#b zb25y?!P2aMT$AX*BT&v|vH=)n$Awp7Ki#cKE28PK&`{98VfWFS_9N!)O<^Ua!Raxu zWCVaF!6>90Zv%M(NK@jXBWOWTLi9#a^4+Rhj(>q_qAMSO%!QrB((s4?gR5V9*X&Z! zM93$89G$KBAAcsbL39LcB^pT8aJQJv0xTr6eyuoxb3|2=2+L}0Neu)wE?WZIsI&1c zq1r>1i*GVA?lu5We%R4c;%&gTIWjyHewDl}PIYa1imlc#peYxxQBqQILQ#jCtBOLSW-!>h9-38Vm52VaMI zs$uh2>+4cC0Ize4%Jg(cUHl5Yy~PwSa-BB;yR!bKy{^?_WXGnN7+@_dVcyd^I#--m z(PZrc0|y5nhGNB~fxusdMNUi_Kk^M=tGc^VT&1{5JESEF?>8r~Zwc&>{kb&C9vdcq zL`hBzUW$AKKu|}vNjV|y>%O<43zMePROa6?!{XAmF%REd0Up+~{2Hh-xlbLzN>8nS zI!E+=l520vHMJPrY9At$K;d!3DfUJK2`zXO@hNdoyHujTN_t1{zDCFP_sxI-QUB8N zU*6IopGjbGP2hl|8qwGyI83UjCY}W}3$n6*t5?W&^5d!qK^%BPk#G0ZMJ-2Oq)?;i z+cyN`m**Y@s*QB8g|dphg=E9GdtVf%s4)z1aMqr7E*30_^Huq#*+`Co2P7TfjfRVm zJpkg+TOc8T=?|qfRUWfujl#WX9iu<{4etX3&Nk+`221BFxY&qMHpMHdP8Hrw0gZfe z+PL>xYgK7Gp#OB}N*DzP!YGD#0`o_?UE=>IzCdb%%kQ=aXIjE>MHSFi%)XleZ&|R| z;XLMS#*hta%t1H+XOPcgx8>?+(`N<=h0R9+^u1ry3bgT?p9BbtCh-8n`$HN$UBjhJ zHwn7NJ2<&(j zE^3uGXG8)3c5-#@Cw{ecOMkwba?V?3Ta25-*9KT?QL3obk<%8-n%z z=->raKh7l}uHi4`u`88k)wNz6PYbX7{Ptb?&lE?Flq4&SMu91nqlT)=#Jw{om&qiT z-Lw1uv6NOHq$<>P==-ga`!w!crFF6vIx_nuvitYVW{;gV1C=-5VE3dw+U+@ z2jgdcQ7mBW1fF4l*QQZ*^C=Y@SE&$0p2aEWHt?C%vxOQIB2U5PvVdp98e85ix4`fB zvVOb&WpByP=dkZ<9@=#Tt*=O@xSH14X)dJ{B-tXAbjQ}xufE>2Sy+B|QLVRll#oD? zgmW>}z=xo)pGK$1%I#^FB+A$4>4`wMTR&Wx2a4a3GSS=7MmSt9x8K0D7*=k9N517x zHl(iJ^sX8W8?<=0jUq4U&Ee7mjqyx5aQ-$u&!6cEJxQO}cko5fVp+1G&W;|#j7 zu_1p|u=leNJqH}r<`U`++L(^t4nCe=u3_F-Nv&+CZ_Kt-_jx)Urz1oK>i4M(j2mRg z%4EyTy0QL|$5t{)JT&}1$t>hi0Gg+#;Et~x!H z#Hd>P6oTBJn`9Pf!|DvUp{Vb~5X@dQ?l?>?r^?F0{qA$4V~CsmIv-~2xg%4}r- literal 0 HcmV?d00001 diff --git a/media/llama1-icon.svg b/media/llama1-icon.svg new file mode 100644 index 0000000000..dcbe9cce9b --- /dev/null +++ b/media/llama1-icon.svg @@ -0,0 +1,87 @@ + + + + + + + + + + + + + + + + + + From 7057faf64b514e991e2f70147f82bb13d544b1c0 Mon Sep 17 00:00:00 2001 From: Aldehir Rojas Date: Mon, 8 Sep 2025 16:14:32 -0500 Subject: [PATCH 24/46] json : support `enum` values within `allOf` (#15830) --- common/json-schema-to-grammar.cpp | 22 ++++++++- examples/json_schema_to_grammar.py | 15 ++++++- tests/test-json-schema-to-grammar.cpp | 45 +++++++++++++++++++ .../public_legacy/json-schema-to-grammar.mjs | 15 ++++++- 4 files changed, 94 insertions(+), 3 deletions(-) diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index 637891f506..182c787544 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -843,9 +843,10 @@ public: _build_object_rule( properties, required, name, schema.contains("additionalProperties") ? schema["additionalProperties"] : json())); - } else if ((schema_type.is_null() || schema_type == "object") && schema.contains("allOf")) { + } else if ((schema_type.is_null() || schema_type == "object" || schema_type == "string") && schema.contains("allOf")) { std::unordered_set required; std::vector> properties; + std::map enum_values; std::string hybrid_name = name; std::function add_component = [&](const json & comp_schema, bool is_required) { if (comp_schema.contains("$ref")) { @@ -857,6 +858,14 @@ public: required.insert(prop.key()); } } + } else if (comp_schema.contains("enum")) { + for (const auto & v : comp_schema["enum"]) { + const auto rule = _generate_constant_rule(v); + if (enum_values.find(rule) == enum_values.end()) { + enum_values[rule] = 0; + } + enum_values[rule] += 1; + } } else { // todo warning } @@ -870,6 +879,17 @@ public: add_component(t, true); } } + if (!enum_values.empty()) { + std::vector enum_intersection; + for (const auto & p : enum_values) { + if (p.second == schema["allOf"].size()) { + enum_intersection.push_back(p.first); + } + } + if (!enum_intersection.empty()) { + return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ") space"); + } + } return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json())); } else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) { json items = schema.contains("items") ? schema["items"] : schema["prefixItems"]; diff --git a/examples/json_schema_to_grammar.py b/examples/json_schema_to_grammar.py index ed37958554..2d57549046 100755 --- a/examples/json_schema_to_grammar.py +++ b/examples/json_schema_to_grammar.py @@ -586,9 +586,10 @@ class SchemaConverter: properties = list(schema.get('properties', {}).items()) return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties'))) - elif schema_type in (None, 'object') and 'allOf' in schema: + elif schema_type in (None, 'object', 'string') and 'allOf' in schema: required = set() properties = [] + enum_sets = [] hybrid_name = name def add_component(comp_schema, is_required): if (ref := comp_schema.get('$ref')) is not None: @@ -600,6 +601,9 @@ class SchemaConverter: if is_required: required.add(prop_name) + if 'enum' in comp_schema: + enum_sets.append(set(comp_schema['enum'])) + for t in schema['allOf']: if 'anyOf' in t: for tt in t['anyOf']: @@ -607,6 +611,15 @@ class SchemaConverter: else: add_component(t, is_required=True) + if enum_sets: + enum_intersection = enum_sets[0] + for s in enum_sets[1:]: + enum_intersection &= s + + if enum_intersection: + rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ') space' + return self._add_rule(rule_name, rule) + return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None)) elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema): diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp index 78ee55e246..67df240c6f 100755 --- a/tests/test-json-schema-to-grammar.cpp +++ b/tests/test-json-schema-to-grammar.cpp @@ -1209,6 +1209,51 @@ static void test_all(const std::string & lang, std::function { const ref = compSchema.$ref; if (ref !== undefined) { @@ -648,6 +649,10 @@ export class SchemaConverter { } } } + + if ('enum' in compSchema) { + enumSets.push(new Set(compSchema.enum || [])); + } }; for (const t of schema.allOf) { @@ -660,6 +665,14 @@ export class SchemaConverter { } } + if (enumSets.length > 0) { + const enumIntersection = new Set([...enumSets[0]].filter(v => enumSets.every(s => s.has(v)))); + if (enumIntersection.size > 0) { + const sortedEnums = [...enumIntersection].sort((a, b) => a.localeCompare(b)); + const rule = '(' + sortedEnums.map(v => this._generateConstantRule(v)).join(' | ') + ') space'; + return this._addRule(ruleName, rule); + } + } return this._addRule(ruleName, this._buildObjectRule(properties, required, name, null)); } else if ((schemaType === undefined || schemaType === 'array') && ('items' in schema || 'prefixItems' in schema)) { const items = schema.items ?? schema.prefixItems; From acc1b008cfd95e63d5f99a370dbffb98e5a99d2c Mon Sep 17 00:00:00 2001 From: "Piotr Wilkin (ilintar)" Date: Tue, 9 Sep 2025 06:05:55 +0200 Subject: [PATCH 25/46] model-conversion : add extra debugging support for model conversion (#15877) * feat: Extra debugging support for model conversion - added BF16 support for llama-callback-eval and support for dumping intermediate steps in run-org-model.py --- examples/eval-callback/eval-callback.cpp | 11 ++ examples/model-conversion/requirements.txt | 9 +- .../scripts/causal/run-org-model.py | 149 ++++++++++++++++-- 3 files changed, 156 insertions(+), 13 deletions(-) diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index d4ef751fbb..cefa39a57c 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -28,6 +28,15 @@ static std::string ggml_ne_string(const ggml_tensor * t) { return str; } +static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) { + union { + float f; + uint32_t i; + } u; + u.i = (uint32_t)h.bits << 16; + return u.f; +} + static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) { size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0]; float v; @@ -43,6 +52,8 @@ static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * v = (float) *(int16_t *) &data[i]; } else if (type == GGML_TYPE_I8) { v = (float) *(int8_t *) &data[i]; + } else if (type == GGML_TYPE_BF16) { + v = ggml_compute_bf16_to_fp32(*(ggml_bf16_t *) &data[i]); } else { GGML_ABORT("fatal error"); } diff --git a/examples/model-conversion/requirements.txt b/examples/model-conversion/requirements.txt index b8148b269a..ac9f69e10b 100644 --- a/examples/model-conversion/requirements.txt +++ b/examples/model-conversion/requirements.txt @@ -1,5 +1,6 @@ --extra-index-url https://download.pytorch.org/whl/cpu -torch~=2.6.0 -torchvision~=0.21.0 -transformers~=4.55.0 -huggingface-hub~=0.34.0 +torch +torchvision +transformers +huggingface-hub +accelerate diff --git a/examples/model-conversion/scripts/causal/run-org-model.py b/examples/model-conversion/scripts/causal/run-org-model.py index f6188ea6f3..78a54abf13 100755 --- a/examples/model-conversion/scripts/causal/run-org-model.py +++ b/examples/model-conversion/scripts/causal/run-org-model.py @@ -9,15 +9,134 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig import torch import numpy as np -unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME') +### If you want to dump RoPE activations, apply this monkey patch to the model +### class from Transformers that you are running (replace apertus.modeling_apertus +### with the proper package and class for your model +### === START ROPE DEBUG === +# from transformers.models.apertus.modeling_apertus import apply_rotary_pos_emb -parser = argparse.ArgumentParser(description='Process model with specified path') -parser.add_argument('--model-path', '-m', help='Path to the model') +# orig_rope = apply_rotary_pos_emb +# torch.set_printoptions(threshold=float('inf')) +# torch.set_printoptions(precision=6, sci_mode=False) + +# def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +# # log inputs +# summarize(q, "RoPE.q_in") +# summarize(k, "RoPE.k_in") + +# # call original +# q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim) + +# # log outputs +# summarize(q_out, "RoPE.q_out") +# summarize(k_out, "RoPE.k_out") + +# return q_out, k_out + +# # Patch it +# import transformers.models.apertus.modeling_apertus as apertus_mod # noqa: E402 +# apertus_mod.apply_rotary_pos_emb = debug_rope +### == END ROPE DEBUG === + + +def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3): + """ + Print a tensor in llama.cpp debug style. + + Supports: + - 2D tensors (seq, hidden) + - 3D tensors (batch, seq, hidden) + - 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head + + Shows first and last max_vals of each vector per sequence position. + """ + t = tensor.detach().to(torch.float32).cpu() + + # Determine dimensions + if t.ndim == 3: + _, s, _ = t.shape + elif t.ndim == 2: + _, s = 1, t.shape[0] + t = t.unsqueeze(0) + elif t.ndim == 4: + _, s, _, _ = t.shape + else: + print(f"Skipping tensor due to unsupported dimensions: {t.ndim}") + return + + ten_shape = t.shape + + print(f"ggml_debug: {name} = (f32) ... = {{{ten_shape}}}") + print(" [") + print(" [") + + # Determine indices for first and last sequences + first_indices = list(range(min(s, max_seq))) + last_indices = list(range(max(0, s - max_seq), s)) + + # Check if there's an overlap between first and last indices or if we're at the edge case of s = 2 * max_seq + has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s) + + # Combine indices + if has_overlap: + # If there's overlap, just use the combined unique indices + indices = sorted(list(set(first_indices + last_indices))) + separator_index = None + else: + # If no overlap, we'll add a separator between first and last sequences + indices = first_indices + last_indices + separator_index = len(first_indices) + + for i, si in enumerate(indices): + # Add separator if needed + if separator_index is not None and i == separator_index: + print(" ...") + + # Extract appropriate slice + vec = t[0, si] + if vec.ndim == 2: # 4D case: flatten heads × dim_per_head + flat = vec.flatten().tolist() + else: # 2D or 3D case + flat = vec.tolist() + + # First and last slices + first = flat[:max_vals] + last = flat[-max_vals:] if len(flat) >= max_vals else flat + first_str = ", ".join(f"{v:12.4f}" for v in first) + last_str = ", ".join(f"{v:12.4f}" for v in last) + + print(f" [{first_str}, ..., {last_str}]") + + print(" ],") + print(" ]") + print(f" sum = {t.sum().item():.6f}\n") + + +def debug_hook(name): + def fn(_m, input, output): + if isinstance(input, torch.Tensor): + summarize(input, name + "_in") + elif isinstance(input, (tuple, list)) and isinstance(input[0], torch.Tensor): + summarize(input[0], name + "_in") + if isinstance(output, torch.Tensor): + summarize(output, name + "_out") + elif isinstance(output, (tuple, list)) and isinstance(output[0], torch.Tensor): + summarize(output[0], name + "_out") + + return fn + + +unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME") + +parser = argparse.ArgumentParser(description="Process model with specified path") +parser.add_argument("--model-path", "-m", help="Path to the model") args = parser.parse_args() -model_path = os.environ.get('MODEL_PATH', args.model_path) +model_path = os.environ.get("MODEL_PATH", args.model_path) if model_path is None: - parser.error("Model path must be specified either via --model-path argument or MODEL_PATH environment variable") + parser.error( + "Model path must be specified either via --model-path argument or MODEL_PATH environment variable" + ) config = AutoConfig.from_pretrained(model_path) @@ -34,18 +153,30 @@ config = AutoConfig.from_pretrained(model_path) if unreleased_model_name: model_name_lower = unreleased_model_name.lower() - unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}" + unreleased_module_path = ( + f"transformers.models.{model_name_lower}.modular_{model_name_lower}" + ) class_name = f"{unreleased_model_name}ForCausalLM" print(f"Importing unreleased model module: {unreleased_module_path}") try: - model_class = getattr(importlib.import_module(unreleased_module_path), class_name) - model = model_class.from_pretrained(model_path) # Note: from_pretrained, not fromPretrained + model_class = getattr( + importlib.import_module(unreleased_module_path), class_name + ) + model = model_class.from_pretrained( + model_path + ) # Note: from_pretrained, not fromPretrained except (ImportError, AttributeError) as e: print(f"Failed to import or load model: {e}") exit(1) else: - model = AutoModelForCausalLM.from_pretrained(model_path) + model = AutoModelForCausalLM.from_pretrained( + model_path, device_map="auto", offload_folder="offload" + ) + +for name, module in model.named_modules(): + if len(list(module.children())) == 0: # only leaf modules + module.register_forward_hook(debug_hook(name)) model_name = os.path.basename(model_path) # Printing the Model class to allow for easier debugging. This can be useful From 70cd37dbbebdb7a2f84f08207f18eabb0b291a55 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Tue, 9 Sep 2025 06:06:52 +0200 Subject: [PATCH 26/46] requirements : update transformers/torch for Embedding Gemma (#15828) * requirements : update transformers/torch for Embedding Gemma This commit updates the requirements to support converting Embedding Gemma 300m models. The motivation for this change is that during development I had a local copy of the transformers package which is what I used for converting the models. This was a mistake on my part and I should have also updated my transformers version to the official release. I had checked the requirements/requirements-convert_legacy_llama.txt file and noted that the version was >=4.45.1,<5.0.0 and came to the conculusion that no updated would be needed, this assumed that Embedding Gemma would be in a transformers release at the time Commit fb15d649ed14ab447eeab911e0c9d21e35fb243e ("llama : add support for EmbeddingGemma 300m (#15798)) was merged. So anyone wanting to convert themselves would be able to do so. However, Embedding Gemma is a preview release and this commit updates the requirements to use this preview release. * resolve additional python dependencies * fix pyright errors in tokenizer test and remove unused import --- requirements/requirements-convert_hf_to_gguf.txt | 4 +++- requirements/requirements-convert_legacy_llama.txt | 11 ++++++++++- requirements/requirements-tool_bench.txt | 2 +- tests/test-tokenizer-random.py | 4 ++-- .../minicpmv-convert-image-encoder-to-gguf.py | 4 ++-- tools/mtmd/requirements.txt | 4 ++-- tools/server/tests/requirements.txt | 2 +- 7 files changed, 21 insertions(+), 10 deletions(-) diff --git a/requirements/requirements-convert_hf_to_gguf.txt b/requirements/requirements-convert_hf_to_gguf.txt index 766745f42f..90c98c3ffe 100644 --- a/requirements/requirements-convert_hf_to_gguf.txt +++ b/requirements/requirements-convert_hf_to_gguf.txt @@ -2,7 +2,9 @@ mistral-common>=1.8.3 -r ./requirements-convert_legacy_llama.txt --extra-index-url https://download.pytorch.org/whl/cpu -torch~=2.4.0; platform_machine != "s390x" + +## Embedding Gemma requires PyTorch 2.6.0 or later +torch~=2.6.0; platform_machine != "s390x" # torch s390x packages can only be found from nightly builds --extra-index-url https://download.pytorch.org/whl/nightly diff --git a/requirements/requirements-convert_legacy_llama.txt b/requirements/requirements-convert_legacy_llama.txt index 859204b27e..f6076142ce 100644 --- a/requirements/requirements-convert_legacy_llama.txt +++ b/requirements/requirements-convert_legacy_llama.txt @@ -1,5 +1,14 @@ numpy~=1.26.4 sentencepiece~=0.2.0 -transformers>=4.45.1,<5.0.0 + +# Embedding Gemma is currently a preview release: +# https://github.com/huggingface/transformers/releases/tag/v4.56.0-Embedding-Gemma-preview + +# The version is needed to be able to convert Embedding Gemma models to GGUF format: +git+https://github.com/huggingface/transformers@v4.56.0-Embedding-Gemma-preview + +# Once Embedding Gemma is officially released, we can switch to: +#transformers>=4.57.1,<5.0.0 + gguf>=0.1.0 protobuf>=4.21.0,<5.0.0 diff --git a/requirements/requirements-tool_bench.txt b/requirements/requirements-tool_bench.txt index b94521fc7f..f7912aff72 100644 --- a/requirements/requirements-tool_bench.txt +++ b/requirements/requirements-tool_bench.txt @@ -1,6 +1,6 @@ aiohttp~=3.9.3 pytest~=8.3.3 -huggingface_hub~=0.23.2 +huggingface_hub>=0.34.0,<1.0 matplotlib~=3.10.0 numpy~=1.26.4 openai~=1.55.3 diff --git a/tests/test-tokenizer-random.py b/tests/test-tokenizer-random.py index c6cdcb5548..93e697607e 100644 --- a/tests/test-tokenizer-random.py +++ b/tests/test-tokenizer-random.py @@ -421,10 +421,10 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl if text1 == text2: # equal to TokenizerGroundtruth? return True # equal to source text? - if tokenizer1.add_bos_token: # remove BOS + if tokenizer1.add_bos_token and tokenizer1.bos_token and isinstance(tokenizer1.bos_token, str): # remove BOS if text2.startswith(tokenizer1.bos_token): text2 = text2[len(tokenizer1.bos_token):] - if tokenizer1.add_eos_token: # remove EOS + if tokenizer1.add_eos_token and tokenizer1.eos_token and isinstance(tokenizer1.eos_token, str): # remove EOS if text2.endswith(tokenizer1.eos_token): text2 = text2[:-len(tokenizer1.eos_token)] return text == text2 diff --git a/tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py b/tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py index f34d858d67..bb2cc4e4ea 100644 --- a/tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py +++ b/tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py @@ -23,7 +23,6 @@ import warnings import numpy as np import torch import torch.nn.functional as F -import torch.utils.checkpoint from torch import nn from torch.nn.init import _calculate_fan_in_and_fan_out @@ -413,7 +412,8 @@ import re import numpy as np from gguf import * -from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer, Idefics2VisionConfig +from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer +from transformers.models.idefics2.configuration_idefics2 import Idefics2VisionConfig TEXT = "clip.text" VISION = "clip.vision" diff --git a/tools/mtmd/requirements.txt b/tools/mtmd/requirements.txt index a9d788f265..0a1f4e8647 100644 --- a/tools/mtmd/requirements.txt +++ b/tools/mtmd/requirements.txt @@ -1,5 +1,5 @@ -r ../../requirements/requirements-convert_legacy_llama.txt --extra-index-url https://download.pytorch.org/whl/cpu pillow~=11.3.0 -torch~=2.4.0 -torchvision~=0.19.1 +torch~=2.6.0 +torchvision~=0.21.0 diff --git a/tools/server/tests/requirements.txt b/tools/server/tests/requirements.txt index 15d024914e..4ea7f19f77 100644 --- a/tools/server/tests/requirements.txt +++ b/tools/server/tests/requirements.txt @@ -1,6 +1,6 @@ aiohttp~=3.9.3 pytest~=8.3.3 -huggingface_hub~=0.23.2 +huggingface_hub>=0.34.0,<1.0 numpy~=1.26.4 openai~=1.55.3 prometheus-client~=0.20.0 From c252ce67c4b99f056eeb18d42a289e28fee03475 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 9 Sep 2025 08:42:10 +0300 Subject: [PATCH 27/46] contrib : add notes about merging PRs (#15881) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * contrib : add notes about merging PRs * Update CONTRIBUTING.md Co-authored-by: Diego Devesa * Update CONTRIBUTING.md Co-authored-by: Johannes Gäßler --------- Co-authored-by: Diego Devesa Co-authored-by: Johannes Gäßler --- CONTRIBUTING.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e68ff92445..6db74a0cf2 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -16,6 +16,9 @@ - Use the following format for the squashed commit title: ` : (#)`. For example: `utils : fix typo in utils.py (#1234)` - Optionally pick a `` from here: https://github.com/ggml-org/llama.cpp/wiki/Modules - Consider adding yourself to [CODEOWNERS](CODEOWNERS) +- Let authors, who are also collaborators, merge their own PRs +- When merging a PR by a contributor, make sure you have a good understanding of the changes +- Be mindful of maintenance: most of the work going into a feature happens after the PR is merged. If the PR author is not committed to contribute long-term, someone else needs to take responsibility (you) # Coding guidelines From 550cf726e133fd0a069d991287fd3a2a3e3e1cbd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 9 Sep 2025 08:11:01 +0200 Subject: [PATCH 28/46] CUDA: fix GET_ROWS for large tensors (#15882) --- ggml/src/ggml-cuda/getrows.cu | 80 +++++++++++++++++---------------- ggml/src/ggml-cuda/ggml-cuda.cu | 4 -- 2 files changed, 41 insertions(+), 43 deletions(-) diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 83d02474f5..2fab33243d 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -2,39 +2,39 @@ #include "dequantize.cuh" #include "convert.cuh" -#define MAX_GRIDDIM_Y 65535 - template static __global__ void k_get_rows( const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ - /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/ + /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/ /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { - for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) { - // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. - const int i10 = blockIdx.x; - const int i11 = blockIdx.z / ne12; - const int i12 = blockIdx.z % ne12; + for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) { + for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) { + // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. + const int i10 = blockIdx.x; + const int i11 = z / ne12; // TODO fastdiv + const int i12 = z % ne12; - const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; + const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; - dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; - const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03; + dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; + const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03; - const int ib = i00/qk; // block index - const int iqs = (i00%qk)/qr; // quant index - const int iybs = i00 - i00%qk; // dst block start index - const int y_offset = qr == 1 ? 1 : qk/2; + const int ib = i00/qk; // block index + const int iqs = (i00%qk)/qr; // quant index + const int iybs = i00 - i00%qk; // dst block start index + const int y_offset = qr == 1 ? 1 : qk/2; - // dequantize - float2 v; - dequantize_kernel(src0_row, ib, iqs, v); + // dequantize + float2 v; + dequantize_kernel(src0_row, ib, iqs, v); - dst_row[iybs + iqs + 0] = ggml_cuda_cast(v.x); - dst_row[iybs + iqs + y_offset] = ggml_cuda_cast(v.y); + dst_row[iybs + iqs + 0] = ggml_cuda_cast(v.x); + dst_row[iybs + iqs + y_offset] = ggml_cuda_cast(v.y); + } } } @@ -42,27 +42,29 @@ template static __global__ void k_get_rows_float( const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ - /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/ + /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/ /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { - for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) { - // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. - const int i10 = blockIdx.x; - const int i11 = blockIdx.z / ne12; - const int i12 = blockIdx.z % ne12; + for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) { + for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) { + // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. + const int i10 = blockIdx.x; + const int i11 = z / ne12; // TODO fastdiv + const int i12 = z % ne12; - if (i00 >= ne00) { - return; + if (i00 >= ne00) { + return; + } + + const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; + + dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; + const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03); + + dst_row[i00] = ggml_cuda_cast(src0_row[i00]); } - - const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; - - dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; - const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03); - - dst_row[i00] = ggml_cuda_cast(src0_row[i00]); } } @@ -98,7 +100,7 @@ static void get_rows_cuda_q( cudaStream_t stream) { const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE); - const dim3 block_nums(ne10, MIN(block_num_y, MAX_GRIDDIM_Y), ne11*ne12); + const dim3 block_nums(ne10, MIN(block_num_y, UINT16_MAX), MIN(ne11*ne12, UINT16_MAX)); // strides in elements // const size_t s0 = nb0 / sizeof(dst_t); @@ -116,7 +118,7 @@ static void get_rows_cuda_q( k_get_rows<<>>( src0_d, src1_d, dst_d, ne00, /*ne01, ne02, ne03,*/ - /*ne10, ne11,*/ ne12, /*ne13,*/ + /*ne10,*/ ne11, ne12, /*ne13,*/ /* s0,*/ s1, s2, s3, /* nb00,*/ nb01, nb02, nb03, s10, s11, s12/*, s13*/); @@ -131,7 +133,7 @@ static void get_rows_cuda_float( cudaStream_t stream) { const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE; - const dim3 block_nums(ne10, MIN(block_num_y, MAX_GRIDDIM_Y), ne11*ne12); + const dim3 block_nums(ne10, MIN(block_num_y, UINT16_MAX), MIN(ne11*ne12, UINT16_MAX)); // strides in elements // const size_t s0 = nb0 / sizeof(dst_t); @@ -147,7 +149,7 @@ static void get_rows_cuda_float( k_get_rows_float<<>>( src0_d, src1_d, dst_d, ne00, /*ne01, ne02, ne03,*/ - /*ne10, ne11,*/ ne12, /*ne13,*/ + /*ne10,*/ ne11, ne12, /*ne13,*/ /* s0,*/ s1, s2, s3, /* nb00,*/ nb01, nb02, nb03, s10, s11, s12/*, s13*/); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index efca2c775a..95170ae11b 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3393,10 +3393,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; case GGML_OP_GET_ROWS: { - // FIXME: https://github.com/ggml-org/llama.cpp/pull/15868 - if (op->src[1]->ne[1]*op->src[1]->ne[2] > 65535) { - return false; - } switch (op->src[0]->type) { case GGML_TYPE_F16: case GGML_TYPE_F32: From a972faebed5fdc4a3d2a844d92d476058c02e02d Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Tue, 9 Sep 2025 14:38:02 +0800 Subject: [PATCH 29/46] CUDA: Add mul_mat_id support for the mmf kernel (#15767) * CUDA: Add mul_mat_id support the mmf Add support for mul_mat_id for bs < 16 * Review: use warp_size, fix should_use_mmf condition * Launch one block per expert, stride along n_expert_used * templatize mul_mat_id * Pad shmem to 16 bytes, add helper function mul_mat_f_switch_ids * Reduce compile times by dividing mmf into f16, bf16 and f32 variants * Divide mmf by ncols_dst * Add missing files * Fix MUSA/HIP builds --- ggml/src/ggml-cuda/CMakeLists.txt | 2 + ggml/src/ggml-cuda/ggml-cuda.cu | 5 + ggml/src/ggml-cuda/mma.cuh | 1 + ggml/src/ggml-cuda/mmf.cu | 382 ++------------ ggml/src/ggml-cuda/mmf.cuh | 470 +++++++++++++++++- .../template-instances/generate_cu_files.py | 11 + .../mmf-instance-ncols_1.cu | 5 + .../mmf-instance-ncols_10.cu | 5 + .../mmf-instance-ncols_11.cu | 5 + .../mmf-instance-ncols_12.cu | 5 + .../mmf-instance-ncols_13.cu | 5 + .../mmf-instance-ncols_14.cu | 5 + .../mmf-instance-ncols_15.cu | 5 + .../mmf-instance-ncols_16.cu | 5 + .../mmf-instance-ncols_2.cu | 5 + .../mmf-instance-ncols_3.cu | 5 + .../mmf-instance-ncols_4.cu | 5 + .../mmf-instance-ncols_5.cu | 5 + .../mmf-instance-ncols_6.cu | 5 + .../mmf-instance-ncols_7.cu | 5 + .../mmf-instance-ncols_8.cu | 5 + .../mmf-instance-ncols_9.cu | 5 + tests/test-backend-ops.cpp | 2 +- 23 files changed, 603 insertions(+), 350 deletions(-) create mode 100644 ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index d3dfc7807d..0d8c5af473 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -44,6 +44,8 @@ if (CUDAToolkit_FOUND) list(APPEND GGML_SOURCES_CUDA ${SRCS}) file(GLOB SRCS "template-instances/mmq*.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) + file(GLOB SRCS "template-instances/mmf*.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) if (GGML_CUDA_FA_ALL_QUANTS) file(GLOB SRCS "template-instances/fattn-vec*.cu") diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 95170ae11b..0f68d68534 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2109,6 +2109,11 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst); return; } + + if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2])) { + ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst); + return; + } } cudaStream_t stream = ctx.stream(); diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 667deb9c65..c1f24243fe 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -1,3 +1,4 @@ +#pragma once // This file contains primitives that expose the tensor core PTX instructions for CUDA code. // The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout. // The documentation for the PTX instructions can be found under: diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu index cfa5c5cce2..16331e9ecf 100644 --- a/ggml/src/ggml-cuda/mmf.cu +++ b/ggml/src/ggml-cuda/mmf.cu @@ -1,343 +1,12 @@ #include "ggml.h" -#include "common.cuh" -#include "mma.cuh" #include "mmf.cuh" -using namespace ggml_cuda_mma; - -#define MMF_ROWS_PER_BLOCK 32 - -template -__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) -static __global__ void mul_mat_f( - const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, - const int ncols, const int nchannels_y, const int stride_row, const int stride_col_y, const int stride_col_dst, - const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) - typedef tile<16, 8, T> tile_A; - typedef tile< 8, 8, T> tile_B; - typedef tile<16, 8, float> tile_C; - - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - constexpr int tile_k_padded = warp_size + 4; - constexpr int ntA = rows_per_block / tile_A::I; - constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I; - - const int row0 = blockIdx.x * rows_per_block; - const int channel_dst = blockIdx.y; - const int channel_x = channel_dst / channel_ratio; - const int channel_y = channel_dst; - const int sample_dst = blockIdx.z; - const int sample_x = sample_dst / sample_ratio; - const int sample_y = sample_dst; - - x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ; - y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y; - dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst; - - const float2 * y2 = (const float2 *) y; - - extern __shared__ char data_mmv[]; - - tile_C C[ntA][ntB]; - - T * tile_xy = (T *) data_mmv + threadIdx.y*(tile_A::I * tile_k_padded); - - for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) { - tile_A A[ntA][warp_size / tile_A::J]; -#pragma unroll - for (int itA = 0; itA < ntA; ++itA) { -#pragma unroll - for (int i = 0; i < tile_A::I; ++i) { - tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col]; - } -#pragma unroll - for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) { - load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded); - } - } - -#pragma unroll - for (int itB = 0; itB < ntB; ++itB) { - if constexpr (std::is_same_v) { -#pragma unroll - for (int j0 = 0; j0 < tile_B::I; ++j0) { - const int j = j0 + itB*tile_B::I; - - tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f; - } - } else if constexpr (std::is_same_v || std::is_same_v) { -#pragma unroll - for (int j0 = 0; j0 < tile_B::I; ++j0) { - const int j = j0 + itB*tile_B::I; - - const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f); - tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; - } - } else { - static_assert(std::is_same_v, "unsupported type"); - } -#pragma unroll - for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) { - tile_B B; - load_ldmatrix(B, tile_xy + k0, tile_k_padded); -#pragma unroll - for (int itA = 0; itA < ntA; ++itA) { - mma(C[itA][itB], A[itA][k0/tile_B::J], B); - } - } - } - } - - float * buf_iw = (float *) data_mmv; - constexpr int kiw = nwarps*rows_per_block + 4; - - if (nwarps > 1) { - __syncthreads(); - } -#pragma unroll - for (int itB = 0; itB < ntB; ++itB) { -#pragma unroll - for (int itA = 0; itA < ntA; ++itA) { -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l); - const int j = itB*tile_C::J + tile_C::get_j(l); - buf_iw[j*kiw + i] = C[itA][itB].x[l]; - } - } - } - - if (nwarps > 1) { - __syncthreads(); - } - -#pragma unroll - for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - if (j0 + nwarps > cols_per_block && j >= cols_per_block) { - return; - } - - float sum = 0.0f; - static_assert(rows_per_block == warp_size, "need loop/check"); -#pragma unroll - for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) { - const int i = i0 + threadIdx.x; - - sum += buf_iw[j*kiw + i]; - } - dst[j*stride_col_dst + row0 + threadIdx.x] = sum; - } -#else - GGML_UNUSED_VARS(x, y, ids, dst, - ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - NO_DEVICE_CODE; -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) -} - -template -static void mul_mat_f_cuda( - const T * x, const float * y, const int32_t * ids, float * dst, - const int64_t ncols_x, const int64_t nrows_x, - const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, - const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, - const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, - const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { - typedef tile<16, 8, T> tile_A; - typedef tile< 8, 8, T> tile_B; - - GGML_ASSERT(!ids && "mul_mat_id not implemented"); - - GGML_ASSERT(ncols_x % 2 == 0); - GGML_ASSERT(stride_row % 2 == 0); - GGML_ASSERT(stride_col_y % 2 == 0); - GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); - GGML_ASSERT( nsamples_dst % nsamples_x == 0); - const int64_t channel_ratio = nchannels_dst / nchannels_x; - const int64_t sample_ratio = nsamples_dst / nsamples_x; - - const int device = ggml_cuda_get_device(); - const int warp_size = ggml_cuda_info().devices[device].warp_size; - - int64_t nwarps_best = 1; - int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2); - int64_t max_block_size = 256; - for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) { - const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2); - if (niter < niter_best) { - niter_best = niter; - nwarps_best = nwarps; - } - } - - constexpr int rows_per_block = MMF_ROWS_PER_BLOCK; - const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4; - const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4; - const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine); - const dim3 block_nums(nrows_x/rows_per_block, nchannels_dst, nsamples_dst); - const dim3 block_dims(warp_size, nwarps_best, 1); - switch (nwarps_best) { - case 1: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 2: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 3: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 4: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 5: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 6: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 7: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 8: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - default: { - GGML_ABORT("fatal error"); - } break; - } -} - -template -static void mul_mat_f_switch_cols_per_block( - const T * x, const float * y, const int32_t * ids, float * dst, - const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst, - const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, - const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, - const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, - const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { - switch (ncols_dst) { - case 1: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 2: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 3: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 4: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 5: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 6: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 7: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 8: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 9: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 10: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 11: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 12: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 13: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 14: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 15: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 16: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - default: { - GGML_ABORT("fatal error"); - } break; - } -} - void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { GGML_ASSERT( src1->type == GGML_TYPE_F32); GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_TENSOR_BINARY_OP_LOCALS; const size_t ts_src0 = ggml_type_size(src0->type); @@ -365,55 +34,72 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr const int64_t s13 = src1->nb[3] / ts_src1; const int64_t s3 = dst->nb[3] / ts_dst; + const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0; + const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0; + // For MUL_MAT_ID the memory layout is different than for MUL_MAT: const int64_t ncols_dst = ids ? ne2 : ne1; - const int64_t nchannels_y = ids ? ne11 : ne12; - const int64_t nchannels_dst = ids ? ne1 : ne2; - const int64_t stride_channel_dst = ids ? s1 : s2; - const int64_t stride_channel_y = ids ? s11 : s12; + const int64_t nchannels_dst = ids ? ne1 : ne2; - GGML_ASSERT(!ids || ncols_dst == 1); + const int64_t stride_col_dst = ids ? s2 : s1; + const int64_t stride_col_y = ids ? s12 : s11; + const int64_t stride_channel_dst = ids ? s1 : s2; + + int64_t stride_channel_y = ids ? s11 : s12; + int64_t nchannels_y = ids ? ne11 : ne12; + + //mul_mat_id: handle broadcast + if (ids && nchannels_y == 1) { + stride_channel_y = 0; + nchannels_y = ids->ne[0]; + } switch (src0->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0->data; constexpr int vals_per_T = 1; mul_mat_f_switch_cols_per_block( - src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1, - ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, - ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); + src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); } break; case GGML_TYPE_F16: { const half2 * src0_d = (const half2 *) src0->data; constexpr int vals_per_T = 2; mul_mat_f_switch_cols_per_block( - src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1, - ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, - ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); + src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); } break; case GGML_TYPE_BF16: { const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data; constexpr int vals_per_T = 2; mul_mat_f_switch_cols_per_block( - src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1, - ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, - ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); + src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); } break; default: GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); } } -bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11) { +bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols) { + + if (ggml_is_quantized(type)) { + return false; + } + if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) { return false; } if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) { return false; } - if (ne11 > 16) { + if (src1_ncols > 16) { return false; } + switch (type) { case GGML_TYPE_F32: return ampere_mma_available(cc); diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index 785f9f211c..bf724bc57b 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -1,5 +1,473 @@ +#pragma once + +#include "mma.cuh" #include "common.cuh" +using namespace ggml_cuda_mma; + +#define MMF_ROWS_PER_BLOCK 32 + void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); -bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, int64_t ne11); +bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols); + +template +__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) +static __global__ void mul_mat_f( + const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, + const int ncols, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst, + const int stride_col_id, const int stride_row_id, + const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + typedef tile<16, 8, T> tile_A; + typedef tile< 8, 8, T> tile_B; + typedef tile<16, 8, float> tile_C; + + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + constexpr int tile_k_padded = warp_size + 4; + constexpr int ntA = rows_per_block / tile_A::I; + constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I; + + const int row0 = blockIdx.x * rows_per_block; + + const int expert_idx = has_ids ? blockIdx.y : 0; + const int channel_dst = has_ids ? 0 : blockIdx.y; + + const int channel_x = has_ids ? expert_idx : (channel_dst / channel_ratio); + const int channel_y = channel_dst; + const int sample_dst = blockIdx.z; + const int sample_x = sample_dst / sample_ratio; + const int sample_y = sample_dst; + + x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ; + y += int64_t(sample_y) *stride_sample_y + (has_ids ? 0 : channel_y *stride_channel_y); + dst += int64_t(sample_dst)*stride_sample_dst + (has_ids ? 0 : channel_dst*stride_channel_dst); + + const float2 * y2 = (const float2 *) y; + + extern __shared__ char data_mmv[]; + + char * shmem_base = data_mmv; + int * slot_map = (int *) shmem_base; + char * compute_base = has_ids ? (shmem_base + GGML_PAD(cols_per_block, 16) * sizeof(int)) : shmem_base; + + tile_C C[ntA][ntB]; + + T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded); + + if constexpr (has_ids) { + __shared__ int has_any; + if (threadIdx.y == 0) { + int local_has_any = 0; + for (int j = threadIdx.x; j < cols_per_block; j += warp_size) { + int slot = -1; + for (int k = 0; k < nchannels_dst; ++k) { + const int idv = ids[j*stride_row_id + k*stride_col_id]; + if (idv == expert_idx) { + slot = k; + break; + } + } + if (j < cols_per_block) { + local_has_any |= (slot >= 0); + slot_map[j] = slot; + } + } + has_any = warp_reduce_any(local_has_any); + } + __syncthreads(); + if (has_any == 0) { + return; + } + } + + for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) { + tile_A A[ntA][warp_size / tile_A::J]; +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { +#pragma unroll + for (int i = 0; i < tile_A::I; ++i) { + tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col]; + } +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) { + load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded); + } + } + +#pragma unroll + for (int itB = 0; itB < ntB; ++itB) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const int j = j0 + itB*tile_B::I; + + if constexpr (!has_ids) { + tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f; + } else { + float val = 0.0f; + if (j < cols_per_block) { + const int slot = slot_map[j]; + if (slot >= 0) { + val = y[slot*stride_channel_y + j*stride_col_y + col]; + } + } + tile_xy[j0*tile_k_padded + threadIdx.x] = val; + } + } + } else if constexpr (std::is_same_v || std::is_same_v) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const int j = j0 + itB*tile_B::I; + + if constexpr (!has_ids) { + const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f); + tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; + } else { + float2 tmp = make_float2(0.0f, 0.0f); + if (j < cols_per_block) { + const int slot = slot_map[j]; + if (slot >= 0) { + const float2 * y2_slot = (const float2 *)(y + slot*stride_channel_y); + tmp = y2_slot[j*stride_col_y + col]; + } + } + tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; + } + } + } else { + static_assert(std::is_same_v, "unsupported type"); + } +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) { + tile_B B; + load_ldmatrix(B, tile_xy + k0, tile_k_padded); +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { + mma(C[itA][itB], A[itA][k0/tile_B::J], B); + } + } + } + } + + float * buf_iw = (float *) compute_base; + constexpr int kiw = nwarps*rows_per_block + 4; + + if (nwarps > 1) { + __syncthreads(); + } +#pragma unroll + for (int itB = 0; itB < ntB; ++itB) { +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l); + const int j = itB*tile_C::J + tile_C::get_j(l); + buf_iw[j*kiw + i] = C[itA][itB].x[l]; + } + } + } + + if (nwarps > 1) { + __syncthreads(); + } + +#pragma unroll + for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j0 + nwarps > cols_per_block && j >= cols_per_block) { + return; + } + + float sum = 0.0f; + static_assert(rows_per_block == warp_size, "need loop/check"); +#pragma unroll + for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) { + const int i = i0 + threadIdx.x; + + sum += buf_iw[j*kiw + i]; + } + + if constexpr (!has_ids) { + dst[j*stride_col_dst + row0 + threadIdx.x] = sum; + } else { + const int slot = (j < cols_per_block) ? slot_map[j] : -1; + if (slot >= 0) { + dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum; + } + } + } +#else + GGML_UNUSED_VARS(x, y, ids, dst, + ncols, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + NO_DEVICE_CODE; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +} + +template +static inline void mul_mat_f_switch_ids( + const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols_x, const int64_t nchannels_dst, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t stride_col_id, const int64_t stride_row_id, + const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, + const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) { + if (ids) { + mul_mat_f<<>> + (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } else { + mul_mat_f<<>> + (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } +} + +template +void mul_mat_f_cuda( + const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t stride_col_id, const int64_t stride_row_id, + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + cudaStream_t stream) { + typedef tile<16, 8, T> tile_A; + typedef tile< 8, 8, T> tile_B; + + GGML_ASSERT(ncols_x % 2 == 0); + GGML_ASSERT(stride_row % 2 == 0); + GGML_ASSERT(stride_col_y % 2 == 0); + GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); + GGML_ASSERT( nsamples_dst % nsamples_x == 0); + const int64_t channel_ratio = nchannels_dst / nchannels_x; + const int64_t sample_ratio = nsamples_dst / nsamples_x; + + const int device = ggml_cuda_get_device(); + const int warp_size = ggml_cuda_info().devices[device].warp_size; + + int64_t nwarps_best = 1; + int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2); + int64_t max_block_size = 256; + for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) { + const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2); + if (niter < niter_best) { + niter_best = niter; + nwarps_best = nwarps; + } + } + + constexpr int rows_per_block = MMF_ROWS_PER_BLOCK; + const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4; + const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4; + const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine); + const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0; + const int nbytes_shared_total = nbytes_shared + nbytes_slotmap; + const int64_t grid_y = ids ? nchannels_x : nchannels_dst; // per expert when ids present + + const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst); + const dim3 block_dims(warp_size, nwarps_best, 1); + + switch (nwarps_best) { + case 1: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 2: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 3: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 4: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 5: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 6: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 7: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 8: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + default: { + GGML_ABORT("fatal error"); + } break; + } + + GGML_UNUSED_VARS(nchannels_y); +} + +template +static void mul_mat_f_switch_cols_per_block( + const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t stride_col_id, const int stride_row_id, + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + cudaStream_t stream) { + switch (ncols_dst) { + case 1: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 2: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 3: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 4: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 5: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 6: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 7: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 8: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 9: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 10: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 11: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 12: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 13: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 14: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 15: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 16: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + default: { + GGML_ABORT("fatal error"); + } break; + } +} + +#define DECL_MMF_CASE_HELPER(T, ncols_dst) \ + template void mul_mat_f_cuda( \ + const T * x, const float * y, const int32_t * ids, float * dst, \ + const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \ + const int64_t stride_col_id, const int64_t stride_row_id, \ + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \ + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\ + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \ + cudaStream_t stream); + +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#define DECL_MMF_CASE_EXTERN(ncols_dst) \ + extern DECL_MMF_CASE_HELPER(float, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst) + +#define DECL_MMF_CASE(ncols_dst) \ + DECL_MMF_CASE_HELPER(float, ncols_dst) \ + DECL_MMF_CASE_HELPER(half2, ncols_dst) \ + DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst) + +DECL_MMF_CASE_EXTERN(1); +DECL_MMF_CASE_EXTERN(2); +DECL_MMF_CASE_EXTERN(3); +DECL_MMF_CASE_EXTERN(4); +DECL_MMF_CASE_EXTERN(5); +DECL_MMF_CASE_EXTERN(6); +DECL_MMF_CASE_EXTERN(7); +DECL_MMF_CASE_EXTERN(8); +DECL_MMF_CASE_EXTERN(9); +DECL_MMF_CASE_EXTERN(10); +DECL_MMF_CASE_EXTERN(11); +DECL_MMF_CASE_EXTERN(12); +DECL_MMF_CASE_EXTERN(13); +DECL_MMF_CASE_EXTERN(14); +DECL_MMF_CASE_EXTERN(15); +DECL_MMF_CASE_EXTERN(16); +#else +#define DECL_MMF_CASE(ncols_dst) +#endif diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index a92a9e52cf..da2d7b7c3b 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -34,6 +34,13 @@ SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do DECL_MMQ_CASE({type}); """ +SOURCE_MMF = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE({type}); +""" + def get_short_name(long_quant_name): return long_quant_name.replace("GGML_TYPE_", "").lower() @@ -76,3 +83,7 @@ for ncols in [8, 16, 32, 64]: for type in TYPES_MMQ: with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f: f.write(SOURCE_MMQ.format(type=type)) + +for type in range(1, 17): + with open(f"mmf-instance-ncols_{type}.cu", "w") as f: + f.write(SOURCE_MMF.format(type=type)) diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu new file mode 100644 index 0000000000..f594d5d51d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(1); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu new file mode 100644 index 0000000000..9cc6772542 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(10); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu new file mode 100644 index 0000000000..317f487d7a --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(11); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu new file mode 100644 index 0000000000..dc0033227c --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(12); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu new file mode 100644 index 0000000000..0782101753 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(13); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu new file mode 100644 index 0000000000..a23ad6ae26 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(14); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu new file mode 100644 index 0000000000..0fe3f7821e --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(15); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu new file mode 100644 index 0000000000..544086375e --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(16); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu new file mode 100644 index 0000000000..3b901797cf --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(2); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu new file mode 100644 index 0000000000..56e940bba0 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(3); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu new file mode 100644 index 0000000000..a7665d49d0 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(4); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu new file mode 100644 index 0000000000..3a1dff2587 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(5); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu new file mode 100644 index 0000000000..400fb7c663 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(6); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu new file mode 100644 index 0000000000..954a1c7e03 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(7); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu new file mode 100644 index 0000000000..f1bd09c945 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(8); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu new file mode 100644 index 0000000000..1255ac2af6 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(9); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index cca86a8ce8..a4776ee7f1 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6261,7 +6261,7 @@ static std::vector> make_test_cases_eval() { for (int n_mats : {4, 8}) { for (int n_used : {1, 2, 4}) { for (bool b : {false, true}) { - for (int n : {1, 32, 129}) { + for (int n : {1, 4, 5, 32, 129}) { int m = 512; int k = 256; test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k)); From ed54e32558ffe0f2c3b1d31f3227211fae05bd49 Mon Sep 17 00:00:00 2001 From: lksj92hs <134250687+lksj92hs@users.noreply.github.com> Date: Tue, 9 Sep 2025 15:01:15 +0300 Subject: [PATCH 30/46] Workaround for subgroup arithmetic failing on MoltenVK with AMD GPUs (issue 15846) (#15886) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index f0fa9e668c..6f130d47f2 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3736,6 +3736,12 @@ static vk_device ggml_vk_get_device(size_t idx) { device->subgroup_arithmetic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic); +#ifdef __APPLE__ + // Workaround for subgroup arithmetic failing on MoltenVK with AMD GPUs (issue 15846) + if (device->vendor_id == VK_VENDOR_ID_AMD) { + device->subgroup_arithmetic = false; + } +#endif device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle); device->subgroup_clustered = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && From 17bc5a815f0bebc1844c99796760cd7df5e9b8d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 9 Sep 2025 14:04:43 +0200 Subject: [PATCH 31/46] HIP: use v_dot2_f32_f16 instruction for FA (#15884) --- ggml/src/ggml-cuda/common.cuh | 25 +++++++++++++++++++++++++ ggml/src/ggml-cuda/fattn-tile.cu | 7 +------ 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 931524a200..394595be0e 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -545,6 +545,31 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i #endif // defined(GGML_USE_HIP) } +static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float v, const float u) { + acc += v*u; +} + +static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v, const float2 u) { + acc += v.x*u.x; + acc += v.y*u.y; +} + +static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) { +#if defined(GGML_USE_HIP) && defined(GCN) + asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u)); +#else +#ifdef FAST_FP16_AVAILABLE + const float2 tmp = __half22float2(v*u); + acc += tmp.x + tmp.y; +#else + const float2 tmpv = __half22float2(v); + const float2 tmpu = __half22float2(u); + acc += tmpv.x * tmpu.x; + acc += tmpv.y * tmpu.y; +#endif // FAST_FP16_AVAILABLE +#endif // defined(GGML_USE_HIP) && defined(GCN) +} + static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { #if CUDART_VERSION >= 12080 const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x); diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu index fb2163acd0..64f7d4a1a1 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cu +++ b/ggml/src/ggml-cuda/fattn-tile.cu @@ -304,12 +304,7 @@ static __global__ void flash_attn_tile( for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { #pragma unroll for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { -#ifdef FAST_FP16_AVAILABLE - const float2 tmp = __half22float2(K_k[i_KQ_0/warp_size] * Q_k[j_KQ_0/nwarps]); - sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += tmp.x + tmp.y; -#else - sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += K_k[i_KQ_0/warp_size] * Q_k[j_KQ_0/nwarps]; -#endif // FAST_FP16_AVAILABLE + ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size], Q_k[j_KQ_0/nwarps]); } } } From 4f63cd705c7b6f457f36c63fdc053e07f6f3cc6b Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 9 Sep 2025 07:41:15 -0500 Subject: [PATCH 32/46] vulkan: Fix OOB accesses in soft_max_back (#15861) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 4 ++-- ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp | 4 ++++ tests/test-backend-ops.cpp | 1 + 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 6f130d47f2..c0c2239c2d 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3387,7 +3387,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); - ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true); ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); @@ -9145,7 +9145,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { float * op_params = (float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], op_params[1] }, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun); } static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp index 29bd77d7e1..144ea58e6f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp @@ -20,6 +20,10 @@ void main() { const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; const uint tid = gl_LocalInvocationID.x; + if (row >= p.KY) { + return; + } + FLOAT_TYPE scale = p.param1; // partial sums for thread in warp diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index a4776ee7f1..adf91ab6f9 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6391,6 +6391,7 @@ static std::vector> make_test_cases_eval() { for (int64_t ne1 : {16, 1024}) { test_cases.emplace_back(new test_soft_max_back(GGML_TYPE_F32, {ne0, ne1, 1, 1}, scale, max_bias)); test_cases.emplace_back(new test_soft_max_back(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, scale, max_bias)); + test_cases.emplace_back(new test_soft_max_back(GGML_TYPE_F32, {ne0, ne1, 2, 3}, scale, max_bias)); } } } From ae355f6f7108540297d8b7f7ae71d20fe610a0b7 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Tue, 9 Sep 2025 22:26:03 +0200 Subject: [PATCH 33/46] vulkan: throw the oom error instead of no memory type found (#15905) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index c0c2239c2d..cb379fe94e 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1937,7 +1937,9 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties(); - for (auto &req_flags : req_flags_list) { + for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) { + const auto & req_flags = *it; + uint32_t memory_type_index = find_properties(&mem_props, &mem_req, req_flags); if (memory_type_index == UINT32_MAX) { @@ -1950,6 +1952,11 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std break; } catch (const vk::SystemError& e) { // loop and retry + // during last attempt throw the exception + if (it + 1 == req_flags_list.end()) { + device->device.destroyBuffer(buf->buffer); + throw e; + } } } From ff02caf9eed261423289d1531a56536fbf57bfc2 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Wed, 10 Sep 2025 05:23:19 +0200 Subject: [PATCH 34/46] ci : cache ROCm installation in windows-latest-cmake-hip (#15887) This commit adds caching of the ROCm installation for the windows-latest-cmake-hip job. The motivation for this is that the installation can sometimes hang and/or not complete properly leaving an invalid installation which later fails the build. By caching the installation hopefully we can keep a good installation available in the cache and avoid the installation step. Refs: https://github.com/ggml-org/llama.cpp/pull/15365 --- .github/workflows/build.yml | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 43553ac13b..1f226d6ea8 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1063,7 +1063,17 @@ jobs: run: | git clone https://github.com/rocm/rocwmma --branch rocm-6.2.4 --depth 1 - - name: Install + - name: Cache ROCm Installation + id: cache-rocm + uses: actions/cache@v4 + with: + path: C:\Program Files\AMD\ROCm + key: rocm-6.1-${{ runner.os }}-v1 + restore-keys: | + rocm-6.1-${{ runner.os }}- + + - name: Install ROCm + if: steps.cache-rocm.outputs.cache-hit != 'true' id: depends run: | $ErrorActionPreference = "Stop" @@ -1071,13 +1081,28 @@ jobs: Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe" write-host "Installing AMD HIP SDK" $proc = Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -PassThru - $proc.WaitForExit(600000) + $completed = $proc.WaitForExit(600000) + if (-not $completed) { + Write-Error "ROCm installation timed out after 10 minutes. Killing the process" + $proc.Kill() + exit 1 + } + if ($proc.ExitCode -ne 0) { + Write-Error "ROCm installation failed with exit code $($proc.ExitCode)" + exit 1 + } write-host "Completed AMD HIP SDK installation" - name: Verify ROCm id: verify run: | - & 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version + # Find and test ROCm installation + $clangPath = Get-ChildItem 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | Select-Object -First 1 + if (-not $clangPath) { + Write-Error "ROCm installation not found" + exit 1 + } + & $clangPath.FullName --version - name: Install ccache uses: ggml-org/ccache-action@v1.2.16 From 86587da03bd78df8f4e7d8b111a0c1d2494d6ed0 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Wed, 10 Sep 2025 05:33:58 +0200 Subject: [PATCH 35/46] llama : check returned fn ptrs from ggml_backend_reg_get_proc_address (#15893) This commit adds check for two function pointers returned from ggml_backend_reg_get_proc_address. The motivation for this is that the function pointer could be nullptr if the get proc address function changes in the future. This is also consistent with all the other calls to ggml_backend_reg_get_proc_address in the code base. --- src/llama-context.cpp | 4 +++- src/llama.cpp | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 874c6f82cb..3e163001c1 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1447,7 +1447,9 @@ ggml_status llama_context::graph_compute( if (backend_cpu != nullptr) { auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu)); auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool"); - set_threadpool_fn(backend_cpu, tp); + if (set_threadpool_fn) { + set_threadpool_fn(backend_cpu, tp); + } } // set the number of threads for all the backends diff --git a/src/llama.cpp b/src/llama.cpp index f0d4f5f891..92cddccc99 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -83,7 +83,9 @@ void llama_numa_init(enum ggml_numa_strategy numa) { GGML_ASSERT(dev && "CPU backend is not loaded"); auto * reg = ggml_backend_dev_backend_reg(dev); auto * numa_init_fn = (decltype(ggml_numa_init) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_numa_init"); - numa_init_fn(numa); + if (numa_init_fn) { + numa_init_fn(numa); + } } } From 28b5f190ef1dbea5edf82dbc8b4407b721fadd13 Mon Sep 17 00:00:00 2001 From: Chenguang Li <757486878@qq.com> Date: Wed, 10 Sep 2025 15:29:12 +0800 Subject: [PATCH 36/46] CANN: implement LRU cache for ACL graphs (#15814) * CANN: implement LRU cache for ACL graphs in CANN backend - Introduce ggml_cann_graph_lru_cache to store multiple ggml_cann_graph objects. - Graphs are loaded on demand and evicted using LRU policy when capacity is exceeded. - Updated push, move_to_front, and clear methods to manage cached graphs efficiently. - Ensures reuse of graphs, reducing graph reconstruction overhead in CANN backend. * fix typo * The LRU cache capacity can be configured via an env variable Signed-off-by: noemotiovon <757486878@qq.com> * refactory acl graph * refactory && fix review comments Signed-off-by: noemotiovon <757486878@qq.com> --------- Signed-off-by: noemotiovon <757486878@qq.com> --- docs/backend/CANN.md | 4 + ggml/src/ggml-cann/common.h | 64 +++++++++++++- ggml/src/ggml-cann/ggml-cann.cpp | 147 ++++++++++++++++++++----------- 3 files changed, 164 insertions(+), 51 deletions(-) diff --git a/docs/backend/CANN.md b/docs/backend/CANN.md index 357253f43a..35b189bb95 100755 --- a/docs/backend/CANN.md +++ b/docs/backend/CANN.md @@ -314,3 +314,7 @@ Converting the matmul weight format from ND to NZ to improve performance. Enable ### GGML_CANN_ACL_GRAPH Operators are executed using ACL graph execution, rather than in op-by-op (eager) mode. Enabled by default. + +### GGML_CANN_GRAPH_CACHE_CAPACITY + +Maximum number of compiled CANN graphs kept in the LRU cache, default is 12. When the number of cached graphs exceeds this capacity, the least recently used graph will be evicted. diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index e295f4ab47..17d7dbc75c 100755 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -38,6 +38,7 @@ #include #include #include +#include #include "../include/ggml-cann.h" #include "../include/ggml.h" @@ -106,6 +107,7 @@ int32_t ggml_cann_get_device(); std::optional get_env(const std::string& name); bool parse_bool(const std::string& value); +int parse_integer(const std::string& value); /** * @brief Abstract base class for memory pools used by CANN. @@ -350,7 +352,7 @@ struct ggml_graph_node_properties { struct ggml_cann_graph { ~ggml_cann_graph() { if (graph != nullptr) { - aclmdlRIDestroy(graph); + ACL_CHECK(aclmdlRIDestroy(graph)); } } @@ -358,6 +360,64 @@ struct ggml_cann_graph { std::vector ggml_graph_properties; }; + +/** + * @brief LRU cache for managing ggml_cann_graph objects. + * + * This class maintains a list of shared_ptr to ggml_cann_graph objects + * and enforces a maximum capacity. It provides methods to push new graphs, + * move existing graphs to the front (most recently used), and clear the cache. + */ +struct ggml_cann_graph_lru_cache { + size_t capacity; /**< Maximum number of graphs in the cache. */ + + std::list cache_list; /**< List storing cached graphs as raw pointers. */ + + ggml_cann_graph_lru_cache() { + capacity = parse_integer(get_env("GGML_CANN_GRAPH_CACHE_CAPACITY").value_or("12")); + } + + /** + * @brief Push a new graph to the front of the cache. + * If the cache exceeds capacity, the least recently used graph is deleted. + * @param new_node Pointer to the new ggml_cann_graph to cache. + * Ownership is transferred to the cache (cache will delete it). + */ + void push(ggml_cann_graph* new_node) { + if (cache_list.size() >= capacity) { + ggml_cann_graph* old = cache_list.back(); + cache_list.pop_back(); + delete old; // free the old graph + } + cache_list.push_front(new_node); + } + + /** + * @brief Move an existing graph to the front of the cache. + * @param node Pointer to the ggml_cann_graph to move. + */ + void move_to_front(ggml_cann_graph* node) { + cache_list.remove(node); + cache_list.push_front(node); + } + + /** + * @brief Clear all graphs from the cache (also frees memory). + */ + void clear() { + for (auto ptr : cache_list) { + delete ptr; + } + cache_list.clear(); + } + + /** + * @brief Destructor that clears the cache and frees all cached graphs. + */ + ~ggml_cann_graph_lru_cache() { + clear(); + } +}; #endif // USE_ACL_GRAPH struct ggml_cann_rope_cache { @@ -394,7 +454,7 @@ struct ggml_backend_cann_context { aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */ #ifdef USE_ACL_GRAPH /// Cached CANN ACL graph used for executing the current ggml computation graph. - std::unique_ptr cann_graph; + ggml_cann_graph_lru_cache graph_lru_cache; bool acl_graph_mode = true; #endif cann_task_queue task_queue; diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 2e47ad90af..aa5913a377 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -116,6 +116,24 @@ bool parse_bool(const std::string& value) { return valid_values.find(value) != valid_values.end(); } +/** + * @brief Parse a string as an integer, returning 0 if invalid. + * + * This function attempts to convert the input string `value` to an `int`. + * If the string is not a valid integer or is out of the `int` range, + * it returns 0. + * + * @param value The string to parse. + * @return The parsed integer, or 0 if conversion fails. + */ +int parse_integer(const std::string& value) { + try { + return std::stoi(value); + } catch (...) { + return 0; + } +} + /** * @brief Initialize the CANN device information. * @@ -2131,30 +2149,52 @@ static void ggml_backend_cann_synchronize(ggml_backend_t backend) { #ifdef USE_ACL_GRAPH /** - * @brief Populate the internal CANN graph node properties from the ggml computation graph. + * @brief Add a new CANN graph to the LRU cache by populating node properties from the ggml graph. * - * This function copies all node attributes (operation type, dimensions, strides, input sources, - * and operation parameters) into the cached CANN graph structure for later reuse or comparison. + * This function creates a new ggml_cann_graph object and fills its node properties + * (operation type, dimensions, strides, input sources, and operation parameters) + * based on the current ggml computation graph. * - * @param cann_ctx The CANN backend context. - * @param cgraph The ggml computational graph. + * Each node in the ggml graph is mapped to a property entry in the new CANN graph: + * - node address + * - operation type + * - shape (ne) and strides (nb) + * - source tensor addresses + * - operation parameters + * + * After initialization, the new graph is pushed into the LRU cache owned by the + * CANN backend context. The cache takes ownership of the graph and manages its + * lifetime (including deletion upon eviction). + * + * @param cann_ctx The CANN backend context containing the graph cache. + * @param cgraph The current ggml computation graph. */ -static void set_ggml_graph_node_properties(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) { - for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) { - ggml_tensor * node = cgraph->nodes[node_idx]; - cann_ctx->cann_graph->ggml_graph_properties[node_idx].node_address = node->data; - cann_ctx->cann_graph->ggml_graph_properties[node_idx].node_op = node->op; +static void add_lru_matched_graph_node_properties( + ggml_backend_cann_context * cann_ctx, + ggml_cgraph * cgraph) { + // Create a new ggml_cann_graph object on the heap (its lifetime is managed by the cache). + ggml_cann_graph * new_graph = new ggml_cann_graph(); + new_graph->ggml_graph_properties.resize(cgraph->n_nodes); - for (int dim = 0; dim < GGML_MAX_DIMS; dim++) { - cann_ctx->cann_graph->ggml_graph_properties[node_idx].ne[dim] = node->ne[dim]; - cann_ctx->cann_graph->ggml_graph_properties[node_idx].nb[dim] = node->nb[dim]; + for (int node_idx = 0; node_idx < cgraph->n_nodes; ++node_idx) { + ggml_tensor * node = cgraph->nodes[node_idx]; + auto & prop = new_graph->ggml_graph_properties[node_idx]; + + prop.node_address = node->data; + prop.node_op = node->op; + + std::copy_n(node->ne, GGML_MAX_DIMS, prop.ne); + std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb); + + for (int src = 0; src < GGML_MAX_SRC; ++src) { + prop.src_address[src] = node->src[src] ? node->src[src]->data : nullptr; } - for (int src = 0; src < GGML_MAX_SRC; src++) { - cann_ctx->cann_graph->ggml_graph_properties[node_idx].src_address[src] = - node->src[src] ? node->src[src]->data : nullptr; - } - memcpy(cann_ctx->cann_graph->ggml_graph_properties[node_idx].op_params, node->op_params, GGML_MAX_OP_PARAMS); + + memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS); } + + // Insert into the LRU cache (cache takes ownership and will delete it when evicted). + cann_ctx->graph_lru_cache.push(new_graph); } /** @@ -2199,30 +2239,45 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra } /** - * @brief Determine if the CANN graph needs to be rebuilt due to graph changes. + * @brief Check whether there is a cached CANN graph that matches the current ggml graph. * - * This checks whether the number or properties of ggml graph nodes have changed - * compared to the last captured CANN graph. If so, the CANN graph must be re-captured. + * This function iterates through the cached CANN graphs stored in the LRU cache and + * compares them against the given ggml computation graph. A match requires that the + * number of nodes is the same and that each node’s properties (operation type, + * dimensions, strides, inputs, and operation parameters) are identical. * - * @param cann_ctx The CANN backend context. + * If a matching graph is found, it is promoted to the front of the LRU cache and the + * function returns true. Otherwise, the function returns false, indicating that a new + * CANN graph needs to be captured. + * + * @param cann_ctx The CANN backend context containing the graph cache. * @param cgraph The current ggml computation graph. - * @return true if an update is required; false otherwise. + * @return true if a matching cached graph exists; false otherwise. */ -static bool is_cann_graph_update_required(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) { - // The number of nodes is different, so the graph needs to be reconstructed. - if (cann_ctx->cann_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) { - cann_ctx->cann_graph->ggml_graph_properties.resize(cgraph->n_nodes); - return true; - } +static bool is_matched_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) { + ggml_cann_graph_lru_cache &lru_cache = cann_ctx->graph_lru_cache; + for (auto &graph_ptr : lru_cache.cache_list) { + // Skip graphs with a different number of nodes. + if (graph_ptr->ggml_graph_properties.size() != static_cast(cgraph->n_nodes)) { + continue; + } - // The number of nodes is the same; iterate over each node to check whether they match. - for (int i = 0; i < cgraph->n_nodes; i++) { - bool has_matching_properties = ggml_graph_node_has_matching_properties( - cgraph->nodes[i], &cann_ctx->cann_graph->ggml_graph_properties[i]); - if(!has_matching_properties) { + // Check if all nodes match. + bool all_match = true; + for (int i = 0; i < cgraph->n_nodes; ++i) { + if (!ggml_graph_node_has_matching_properties(cgraph->nodes[i], &graph_ptr->ggml_graph_properties[i])) { + all_match = false; + break; + } + } + + if (all_match) { + // update cache_list && renturn graph_ptr + lru_cache.move_to_front(graph_ptr); return true; } } + return false; } #endif // USE_ACL_GRAPH @@ -2241,17 +2296,13 @@ static bool is_cann_graph_update_required(ggml_backend_cann_context * cann_ctx, * @param cann_graph_update_required Whether graph capture is needed due to graph changes. */ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph, - bool & use_cann_graph, bool & cann_graph_update_required) { + bool & use_cann_graph, bool & cann_graph_update_required) { #ifdef USE_ACL_GRAPH + ggml_cann_graph* matched_graph = cann_ctx->graph_lru_cache.cache_list.front(); if (use_cann_graph && cann_graph_update_required) { - if (cann_ctx->cann_graph->graph != nullptr) { - ACL_CHECK(aclmdlRIDestroy(cann_ctx->cann_graph->graph)); - cann_ctx->cann_graph->graph = nullptr; - } ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL)); } #endif // USE_ACL_GRAPH - // Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph. // With the use of CANN graphs, the execution will be performed by the graph launch. if (!use_cann_graph || cann_graph_update_required) { @@ -2272,12 +2323,12 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx #ifdef USE_ACL_GRAPH if (use_cann_graph && cann_graph_update_required) { // End CANN graph capture - ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &cann_ctx->cann_graph->graph)); + ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph)); } if (use_cann_graph) { // Execute graph - ACL_CHECK(aclmdlRIExecuteAsync(cann_ctx->cann_graph->graph, cann_ctx->stream())); + ACL_CHECK(aclmdlRIExecuteAsync(matched_graph->graph, cann_ctx->stream())); } #endif // USE_ACL_GRAPH } @@ -2311,19 +2362,17 @@ static enum ggml_status ggml_backend_cann_graph_compute( } if (use_cann_graph) { - if (cann_ctx->cann_graph == nullptr) { - cann_ctx->cann_graph.reset(new ggml_cann_graph()); - cann_graph_update_required = true; + // If no matching graph is found, the graph needs to be recaptured. + cann_graph_update_required = !is_matched_graph(cann_ctx, cgraph); + if (cann_graph_update_required) { + // If no matching graph is found, add a new ACL graph. + add_lru_matched_graph_node_properties(cann_ctx, cgraph); } - - cann_graph_update_required = is_cann_graph_update_required(cann_ctx, cgraph); - set_ggml_graph_node_properties(cann_ctx, cgraph); } #else bool use_cann_graph = false; bool cann_graph_update_required = false; #endif // USE_ACL_GRAPH - evaluate_and_capture_cann_graph( cann_ctx, cgraph, From 10d8b2b6b0ac2cae252a80b4daea5da55ab63c2f Mon Sep 17 00:00:00 2001 From: Chenguang Li <757486878@qq.com> Date: Wed, 10 Sep 2025 18:42:00 +0800 Subject: [PATCH 37/46] CANN: Add ROPE sin/cos cache for reuse (#15912) * CANN: Add ROPE sin/cos cache for reuse Introduce sin/cos caching mechanism in ROPE to avoid redundant computation across layers. The cache is built on the first layer per device and reused by subsequent layers if parameters match. - Added sin_cache / cos_cache pointers and position_length tracking - Introduced cache validity flags and properties: (ext_factor, theta_scale, freq_scale, attn_factor, is_neox) - Accelerates ROPE by eliminating repeated sin/cos generation This change reduces overhead in multi-layer scenarios while preserving correctness by verifying parameter consistency. Co-authored-by: hipudding * fix typo Signed-off-by: noemotiovon <757486878@qq.com> --------- Signed-off-by: noemotiovon <757486878@qq.com> Co-authored-by: hipudding --- ggml/src/ggml-cann/aclnn_ops.cpp | 62 +++++++++++++++++++------------- ggml/src/ggml-cann/common.h | 15 ++++++++ ggml/src/ggml-cann/ggml-cann.cpp | 3 ++ 3 files changed, 56 insertions(+), 24 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index ac2e2e1adf..434023dd22 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -2268,8 +2268,6 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx, * stream, and persistent buffers for rope init/cache. * @param dst The destination ggml_tensor whose computation * depends on the RoPE values (usually Qcur/Kcur). - * @param sin_tensor_buffer Pre-allocated buffer for storing repeated sin values. - * @param cos_tensor_buffer Pre-allocated buffer for storing repeated cos values. * @param theta_scale Scalar exponent base for computing theta scale values. * @param freq_scale Frequency scaling factor, applied to theta scale. * @param attn_factor Attention scaling factor, applied to sin/cos. @@ -2277,17 +2275,23 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx, * (dim expansion vs repeat_interleave). */ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, - void* sin_tensor_buffer, void* cos_tensor_buffer, float* corr_dims, float ext_factor, float theta_scale, float freq_scale, float attn_factor, bool is_neox) { - // int sin/cos cache, cache has different repeat method depond on - // @param.is_neox - ggml_tensor* src0 = dst->src[0]; // input ggml_tensor* src1 = dst->src[1]; // position ggml_tensor* src2 = dst->src[2]; // freq_factors + if(src2 == nullptr && ctx.rope_cache.cached + && ctx.rope_cache.ext_factor == ext_factor + && ctx.rope_cache.theta_scale == theta_scale + && ctx.rope_cache.freq_scale == freq_scale + && ctx.rope_cache.attn_factor == attn_factor + && ctx.rope_cache.is_neox == is_neox) { + // use cache. + return; + } + int64_t theta_scale_length = src0->ne[0] / 2; int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1}; size_t theta_scale_nb[] = {sizeof(float), sizeof(float), sizeof(float), @@ -2316,8 +2320,6 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, ctx.rope_cache.freq_scale != freq_scale) { ctx.rope_cache.theta_scale_length = theta_scale_length; - ctx.rope_cache.theta_scale = theta_scale; - ctx.rope_cache.freq_scale = freq_scale; if (ctx.rope_cache.theta_scale_cache != nullptr) { ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache)); @@ -2342,7 +2344,7 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, // return MIN(1, MAX(0, y)) - 1; yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float)); void* yarn_ramp_buffer = yarn_ramp_allocator.get(); - acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float_t), + acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); float zero_value = 0, one_value = 1; float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]); @@ -2411,6 +2413,20 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, ggml_cann_release_resources(ctx, acl_freq_factors_tensor, acl_freq_fac_res_tensor); } + // init sin_repeat && cos_repeat, only to accelerate first layer on each device + if (position_length > ctx.rope_cache.position_length) { + ctx.rope_cache.position_length = position_length; + if (ctx.rope_cache.sin_cache != nullptr) { + ACL_CHECK(aclrtFree(ctx.rope_cache.sin_cache)); + } + if (ctx.rope_cache.cos_cache != nullptr) { + ACL_CHECK(aclrtFree(ctx.rope_cache.cos_cache)); + } + int64_t repeat_theta_length = theta_scale_length * position_length * 2; + ACL_CHECK(aclrtMalloc(&ctx.rope_cache.sin_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST)); + } + // position aclTensor* acl_position_tensor = ggml_cann_create_tensor( src1->data, ggml_cann_type_mapping(src1->type), @@ -2462,10 +2478,10 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1]; } aclTensor* acl_sin_repeat_tensor = - ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float), + ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float), sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); aclTensor* acl_cos_repeat_tensor = - ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float), + ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float), sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); // repeat @@ -2483,6 +2499,14 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, num_repeats, output_size); } + // Other layers use cache except first layer. + ctx.rope_cache.cached = true; + ctx.rope_cache.ext_factor = ext_factor; + ctx.rope_cache.theta_scale = theta_scale; + ctx.rope_cache.freq_scale = freq_scale; + ctx.rope_cache.attn_factor = attn_factor; + ctx.rope_cache.is_neox = is_neox; + ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor, acl_theta_tensor, acl_sin_tensor, acl_sin_repeat_tensor, acl_cos_tensor, acl_cos_repeat_tensor); @@ -2504,10 +2528,7 @@ aclnnStatus aclnnRotaryPositionEmbedding(void* workspace, #endif void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - // TODO: use ascendc - // Only test with LLAMA model. ggml_tensor* src0 = dst->src[0]; // input - ggml_tensor* src1 = dst->src[1]; // param float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; @@ -2538,15 +2559,8 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; - // sin/cos tensor length. - int64_t repeat_theta_length = src0->ne[0] * src1->ne[0]; - ggml_cann_pool_alloc sin_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float)); - ggml_cann_pool_alloc cos_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float)); - void *sin_tensor_buffer = sin_tensor_allocator.get(); - void *cos_tensor_buffer = cos_tensor_allocator.get(); - // init ctx.rope_cos/rope_sin cache - aclnn_cache_init(ctx, dst, sin_tensor_buffer, cos_tensor_buffer, corr_dims, ext_factor, + aclnn_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox); int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1}; @@ -2556,10 +2570,10 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1]; } aclTensor* acl_sin_reshape_tensor = - ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float), + ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float), sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); aclTensor* acl_cos_reshape_tensor = - ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float), + ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float), sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); aclTensor* acl_src = ggml_cann_create_tensor(src0); diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index 17d7dbc75c..c5fce8dc91 100755 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -425,12 +425,27 @@ struct ggml_cann_rope_cache { if(theta_scale_cache != nullptr) { ACL_CHECK(aclrtFree(theta_scale_cache)); } + if(sin_cache != nullptr) { + ACL_CHECK(aclrtFree(sin_cache)); + } + if(cos_cache != nullptr) { + ACL_CHECK(aclrtFree(cos_cache)); + } } void* theta_scale_cache = nullptr; int64_t theta_scale_length = 0; + // sin/cos cache, used only to accelerate first layer on each device + void* sin_cache = nullptr; + void* cos_cache = nullptr; + int64_t position_length = 0; + // Properties to check before reusing the sincos cache + bool cached = false; + float ext_factor = 0.0f; float theta_scale = 0.0f; float freq_scale = 0.0f; + float attn_factor = 0.0f; + bool is_neox = false; }; struct ggml_cann_tensor_cache { diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index aa5913a377..d148174f1e 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2353,6 +2353,9 @@ static enum ggml_status ggml_backend_cann_graph_compute( ggml_cann_set_device(cann_ctx->device); g_nz_workspaces[cann_ctx->device].clear(); + // calculate rope cache for fist layer in current device. + cann_ctx->rope_cache.cached = false; + #ifdef USE_ACL_GRAPH bool use_cann_graph = true; bool cann_graph_update_required = false; From 09e72a037c77b6823fde9e1e4cf28e815b9c41ab Mon Sep 17 00:00:00 2001 From: Jesse Date: Wed, 10 Sep 2025 07:28:47 -0400 Subject: [PATCH 38/46] gitignore : Ignore vim swap files in tests (#15901) --- tests/.gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/.gitignore b/tests/.gitignore index 620a48ee44..cbc381606c 100644 --- a/tests/.gitignore +++ b/tests/.gitignore @@ -2,3 +2,4 @@ !*.* *.o ggml-common.h +**/*.swp From 2cfef4d117d67ab1dec002915b48a15d11ee1973 Mon Sep 17 00:00:00 2001 From: j-k Date: Wed, 10 Sep 2025 12:51:28 +0100 Subject: [PATCH 39/46] media : add transparent icon svg and png [no ci] (#15891) --- media/llama1-icon-transparent.png | Bin 0 -> 14270 bytes media/llama1-icon-transparent.svg | 77 ++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+) create mode 100644 media/llama1-icon-transparent.png create mode 100644 media/llama1-icon-transparent.svg diff --git a/media/llama1-icon-transparent.png b/media/llama1-icon-transparent.png new file mode 100644 index 0000000000000000000000000000000000000000..432d6c2223bb445362eadf946ecafea5ea4fd558 GIT binary patch literal 14270 zcmd73`8(8K_&;vTk}!j^Z)1yOe<5QDV-J-!YecpZO2(2cV(i(X82gfttt4g`OC++E zB^f(o9sAgQ&h-AgzkkB_`uyN>U1y%pbMABB=U&d^aX+K3o9Z*2OzW4*sGpn(x^}BR$s*=-sO%OrvAHkD_{3WNIlCFjQv*# zvhBS?qEPG^v^LfDPrjz&=%Q4ZI2X}dFML#i_PvU<(!5Je)Oio;ai-g*qtU_`Y6JpE zV)}b15S3?PM?7=MZ{cyPFOdQYWvVYQ{`_3=#w}6FKzhzKw8GU#*#CK&|MR97qELpK z&ae1!aT8D^%BAny#H;d|k&#W;$-vhy4~M>Fxc2F`p=oVi{B-vam$o_6yc{^I<2>a= zq=7=!-A#^u{pl5AR-Po~%EsCIgdty|Q1N}DerebEzA5Zw!N=V35zZ$s11s}!&Xew99Wq5aFi4!9PA#+%6|>ELotb9e?A>@;T~*Ldq!ytK}dv0WL`XlC@OYJe^60fzx%K&$v~;gt`oGLd_#Pp`~} zY+Q-^iYkbrwc)N%-4R7hIpa+65+<7$9r7xC^-(Cbw2!&Q4!4U<1wY)UTbV)+sQJh# zTv%p!)cf3iS*7Luf`DMz{f4X`)&H65!d~TTrL}`oiMVwtw$4gS5j%hjTQ<(}CDtu{ zx{$XMd38d>j&($fW{in3fTBU$mFBVUNpFDlo2}#~WTm1uG;5z@tDr!}VzL=un3cDKW)xP)k5U;J3 z$e{NRn~0A?N@r^J@mxQh)8?Lbm#)HDH~3K#)IKV8Cqlt8FKsA>3XxFMV*#f6b0*LK z+=Cs8$^L~Gc4ifb9wi_VIzKNFyY+g6r*&|qc6i;>zC1aO(I9Sbu6Hi04{-=$Y>~pe zT+IRz&&A{g6gQOSWdDM~Ar{%=q_)H-n4}Ay#9Ub|wws#3DzL6AFYR<;gd=)lEyo&i zyoK7=@c&sk9{qu9iNq)p(ymp^+6)!R7H0;ZPa!k~a zFr^ij4CLY@q7y@`7_bxT@tkT3ZmSDdv2Y%T9^qfv`&GzrgA0rFHnkI(yGVP?5SXo5 z!$`CZ&lM*B#oc4NwnED{75M}AYX8=giON+0aoDL*nDm<)Fvj3=nejnG^D~1^(c= zn|XXF;cyR=_M$RNh=13QNS|0XAb}j!crH!pv(sa6m2WQ{*_H!bCX_Jj_^4ZGol7!? zf2#-5Dan`Hh5P9m%+%h1QAmP*%&8&WAx5uO+xbT|tR3F(J{g2Rgz$@xIAeto)z`wL zsm;+t+U(LBJqF*++iqcmYmIj*r&8L3_-Y0*=Cq=+<}7YdE^c6;YoF$iak+S|C|M1y zLK2i1OIcbQFUEbSuDz#BOwl^v%O#cKp)iYLtu*&iQ+_i=H7otHFp_?I`QA%pTgYuI zeTTwjK}-w$;p-*HiEs^31Y!ssXBQv4z?5-jWg4AGU&HbYk6p$p&^iV!%jOMG#g@fi zp^+}CL6#sR4ESC89xi`yrwf@okszMnlrB-bGRdx_1yLmoPDL?YA5_5{efT*i-N!R2b+Rs z?{H}=$OrZrmKR7!g6XM)olKS@?PJQyeB_VUoHpLlHnV7SPz_RYPTmJNHHJ;Wz?NXe((wEJGT4e6>aFZ7Vqtf6JyIiSR&}}QaE7DBiia-_TWn~3W!b9FzMd<$J6vw7B9Bv zcP&i>kB6Tcn>$f-_&J}8yX-t$&k}GK#U4s)llA#&%OZ{+gbj|C!jj;3mc#v#5cy5~ z-eoJ}Zr9Ahbfob^zeEV0|tT&xbXscV|R3hb|FZ4Q3_ zNDh&}b1|fMoKgO`9Y68`cyov=o=fL}R8IbaF`f&io2Qi)6neUDYJh8vys z_5VS>-br8|uPsI?cay+UZp+Q-=w-B=q7Nl{ylLwrL;}WCf0bTWMg1VSgd7YXl;#Z- zcWy5%j6L_>J*q&yGPvQ|7xVFtzpP#IAyURt47JARNVsq<_`V@G`k$YQPq187MsXJ(~R|_V3 z{uxcfJBOB|;gD|#$OVz7A1qcj!}n6w3w$7lZco3PP(3$u3k290?fQfs>Kg1N z+RNcx1nc_ma2><}SY|2l%}gi**`(8tuxzz1relvnJ3*BD^v*u;8Hg;i7Tnt{XN!F~ zL?$Je@tRl`IysFcw5&^LuiVyr2Ug+V1F#R4;a+)2y{%Cg`HUPBeKlGY!NrcoM5QR6 zA!Da=NL{d1S;Vr>;fC%_1Mqsg6^4x!YJ~I^e|*PUm5y^BEq5_AgRo+-yTv!|_&6j=syu2UeZQW3X$I2C8QK-RDJ1E!Jqb~=xK{o0#f_s^f?5G+Q4)>Xag z_&bnK3CO1?&QP%>ZD4wq zPCZ;o^PKaURMFtgcu3-9Ab8Y<8!^1?-lre%DZlzZ;!b-dZXr8Y>4BNaFXp}HZxCO< zNsOHh*C4)JV8~oV%PLMO5vib~s!aYj5IknUMN$e~`+f=z7LPTz1d)+}1i(aRMJYPiQ}S}Z-xb_)Gkg>eSMMuGYR3WWgYa~9UaIWYy9;K!Gr z1H?so9*YxcD71a^k=i;bc}f4@i}4$#2G-qpJOc${>L;Fo$$rQCh|&;IYY>)&{M|Xj zfa!({d$m*I8Ez$vC*wqju@Uq)RM!lFeg<+DY+UcbUma=CooXd0WznZ2^>o6!H8yV- z$e!Kf32df`-!)ygm&sNAd8l=9zb=Rw61^=E1{*EV>cvvwA;TW*WjhVOJUf5gz?J^B zZR(feokjoCLvK?jp7M#*HLrlx-y4c{pT?GztyKmZt|Dqp!?K*8C5&?TCL$wvdW$o@ zP5FKqesXp!*=;3SeKeufV1=AXjS3Wf@%qT`RIZWV%tM@x4Q(7i>POP^&i|qDaV4eiK(n zQd}*KydJ*1(0ev>Wz%9GRt-X!wLJRiL=rCcgeBND`T<|A4kL8IUmsguGNFUg> zE5<#zN4h}@_WzU1bXs4+(KA&YB@Fh42+PuO!cpSSODNZD2;8rqrx0B1jWr9ivG`hp zZxKtP;@-x*g+G+BOF`=W;JfW{j{r z>@1(1W(XwW=yUfm%S|)N($LSGy^s-lHHO^yQOO*S!TVN?Jpv_+qCZ5&$ris>B^%%`cw$@CFSECXS-3SVC(093`SIhn&qMU9AHug%m`e_ zf!Thann>FyPpI%D9UoG(hQ{I8a5|j9Vk52hd{OC|4!h3NZz4;i{8j#Q7&l0LwHy)& zs+;I;ZRTJPiL(9jR^Lybn8#P@uqW*!0|+N6Tlo9zKoAtv&acwM@Q->UF2Ryi_;MFI zXZqJ%y6I-Vm8~68tiRPdUOCaEKext04TCY|O&mQFU750ZZPDzcv}5=f1~Vzr>P_}H z4Abz1*EG&GVT2I~TI1rsTW#Lq8V14C4czw#AQQ{quBtuprs?JJYt)(UrN&OU9-<&ok+7eq6`XsXQ}fmObJaH|l4(Rr4yaaG zmU4hUT>;|U?j!B=-nAV8ub%2SycP`h6UdR4Cy+Z~JZY6;+4!SOmlV)}LWvk=v!`*~ zjzRrVq7Wsfg#T{-D7JpZiFBk_{7bjsy?1yPhj2V!M>~@)_m+6`N0D0N_flVbfdp+C zwCXZLq_|6M9W-2D*v6b0@*A;>p#TMXNvp2EGy~4%yYCmE(~OJO{&A+IF(?d(9x_IS z8?(fk8xJNe^JEh&)MKy?2*kj(uq*+%Kb_+B*bN($o43^0E-=3cBk*`=2YhZaHsBQM z$E#YheS#;C8DIyZE41$9!npvqBM;b26XX4Fqj3{ zb)0NZgB^TaH>YJK>kN6&FGs64A@(mj2mj0W$zi0gN#+$F0)_!D^@>(qxy4i>Sl~df zQijz^Jaw%Y0|eaE3(MlE4MR5HIpU4cyw)+$fYkw`AmR!D3`J>1<7Hs@==j}$NJrbx zf%*gpHj)L|6|1Ccj|thp znIaI!ECY-XkNQuWFE2^;`lW>(cTJqjI0>M?N^?U%mey_VH_j;bcKEui6AN^V8u$R@ zB;pmFYl-If?hFBsy^8w?ZleUVpD~txT%IXNTF*=2HH#}YOMFRZfk)PDH6K?s>d3-r zR8Y7sSurYjz=C^ZDo!0YjvG%O*Fmj09nFwD!8Hy&hM-~c5Asg+5M6)T@$xtB@pOOM zO^cC*k*VT8jlzX}M)dA>$#TqMo{&Wh@4!W#PM)IwWS_(+MyY(VAJVfs>nY1VEmS+kN{w{L&A@p?Sv|?C-b=vW zfcb&+5o5WZfD2`4TGFG+o_!to=xeaqKDaTi?RVX1>u10*;QDdj-H)XFhFUq5DV1-2 z;d#IPx2R(>TJa4ym2Pu~p@B&O&7M4R&hrOCk^t*N<*`+IQia~7X2T>uot2B_q2Dz& z3e=pM>EhpIX6*!ZmF=$|ER+Ev>qc@y*5B8>UOju`er<9m$^0|w&E~D8sd7#a%3YGF zBuce17S~JPpHL$8brfzuh9LZIT+-CqBEAbc$4D&7Cpnu0p=YJiZKt#7;>G4~p(bc+ zi;+p5nj^mvQaKgQa;cC2x+urce{vRwI0ae$t+@DmA~nW5|8`8}W~m&Wath%sss8gs zk6tsc?RK}NfBwgYn8$Dbz3^FaL99IPDjgb|-E-to^m`Y3@fDI)g#&Hcn@7j+A{lr0 z>`wgIL`qOZ={}=9_KF=`BTRpTJeT%>-A?wv9-+0FJkQpUTd)_)i|<=Dc_R68MCt z4a2fHK5rTwXp#a5u^zHT-qH#eD%`&h8xQ>5bzqp%p}`xSMrIp({f*nWy;LY;ZUI09 zJrP$B;@_tVP}?fDRaA%K=+eiZ=?W*YJ{xU=mSuXXLIs38!;$fIR!lE4Mx^rxj_HV; zi3KiwTMr0Co%^A>#eG$?%1n}OjWObJ_N(7dX`Db^x?sTl>b~8JH}Qp2_@T=VOo#jW^op4V#{|?NKUD+z_%RtJwuLZ47Xrx8CpDBOd ze0$5J>8{bz;%R}Oz5VYCCq0%gv2RDQnued81wRa05+s9T-+ME*SK+U~GO7eTMGGa2 zOFm-hG1rak zu8ueEt2fgO^2nL6e=ELE*>1|+yXbLuLlG~+b(=LdAx}Kl&?gSJcag`^bc-44Ck?bj zbJ&qH%#f+2Pwcf0IJsT^LQKeVD0u0n$kv(&|IT@UTrCVpkT8|L=bNcQpS>Ng!NC`@ zr62QkL=c`<4r0v_-0hEPQ(pxFkfRrrt}*H2QF}6~=G#(RmisCf0~trU;JO7~E2^sX z6{L@l{kyr4zDfA+16f3aZ0~zJz%Qc?4HAZUjInlWA2V;(SH|c7b_YN8U{S84H=^Fdhy+ZJ8 zm*RYqm0nx|QgaC{R4ob+*LnaxLC0i5$H{NZ%EVfE0kHl15J2zfyVK22Xf8BTo!j~d zb$y5H_y^i^!L0H#rX4+jNiUWkQwe3R0x=wo2`lQf+Zlxg8Bd~g-njgvqN*U7 zLS3t`viA(98(#o39(JP-9vJOBV!)8MwHQm3Ek#x#Bfeh+zfFE9KZK+Pox!xhbwwzb zE8n~Y;3Rb+Sjz5$K2iUcC92&z8n(^_xRnXGj$;nM0^M!QnPJf$G|cO!P{Z#IG!5X2 z?-)QxX?ScyhOAc7a9MZ>SQJBUm{G4eFv6Hln6%ZMOJ-C-47D-HXoIIM?!496S>p^m^%K6S zi;uX0XKr5(qn#eRO{iBHG781$Ag$mL^cQRqgaE2%f96BcYiOZqM6YU^ zb)U_eu5$5_?YVLKl*@s*K^(W8Iwa(_orB=HiEW2-oGdH!Db|t-!qz#?Xk&hYA3;~C zIFS275c9b$)5KtZ&X#e`V4+$%S^l<`Fq(k$y2hj%!gn_(l-02|_cxMC@_m2!gfKNc zC{zy_m-7j$l&RD_7FcZ1l&#)q%1mLDa`GWo`Kx`H4W-4tuLq8UdP4#CD`x*?IY|ol zR6sQ4vp>o2UCSL<-S*0N^8UQ4&&wb5-J=L_MKSxm7>4jTY7PW8R`m34XDkC zYCKb6As_})sQjqvpetY6cC9$~=_DBsG5NKNEkj;q36^q~sP(BUa!wV4cDW{s@)!7k z9peBh#;f4wu1RCMDc$)s{i+Of7&856(R%mLI&;W|%m@ZwiH4X6k1#-NDbG#J_ZnwE zM_i^`-guuR4uF=D&o&fFy>Q{;B|Ao;7in+e?s81KU+3U!pypA+MA&)I;AsHEq$nuV z?chTsgJ0Z+VYVVRb9bS2{Np=+UDP3srfMqzw=Ss9^UXLlZ?huDk!tPUm0yzZbcbp` zQ1)zJ4(GKKQvl?ndrEybNdHcJ&VYw#=xJnX~I@uogj{;%`($kR#IcD)I^_UX^~|K7@i&=5FLAU*kCmHT&~V}fO2wdG`{!{kLS42POf@~Xo-lj++qA*@mp2Ei6 zLv*AN7xRTE9x+H*!G4f1fl1_po8d{rAEI4D@ZhDC=bYI&_(NB0wwl9-^8BiS$FW;n zBd&X*-n;zyH8MBZFXJf(uJRahX|=FP`&ct8xh1 zT>3aeLBXL(?ge<4E0+d~!cZ4zxw%)+;Fr;f5AG^pg3h72x4JWa*TmW^>$AS6$~dLJ zJ7Wdh%TVcBTbk+D1ueh}Asn7i+b*G#v;*xF7FYITk6@TGWQRRS018uiCg-|od81zUe-PBzAYFvQ#g`4?07ev~0>@{UK`>At3gGxxD zKmN>{W5IB!e1{8*2K9hn$|5~bWWc?qNYz~=oS7KwP2VKQe|!NffinH-Z1!oM8i&Vo zNmRbBx-2f|Fr`X_Kob@F#)hg3RNbZe-)T}d_!p}*#Q3(}lKfJ($qdA0L_h6K8d&0T z=@w}9&2P5`@D?a9&U2V?xBf{aQ8xG%t61u@+x{%WehiHt~o#rvsO7Y`l1F2Kl} zdbAwD)Y;^eMK~0TD||$7y|AK&!zl2onReoPO78w9-=3?r!?NxGTvx7Vy=UECV0vae z!T&C3B#EL`FE?uKrD_NA*S+akBa&%!$2eN`)NXnVkiXd+k&9}20!5rBY{}?i(rxbTPbH>2#*^iMPFilwBatH<*ykOv zPch45rD&wW-?@LVgEm5>OOdc>{%LU4%v@gGR~O-kQY5&_lhOm50kWM~DUDqKsw03g z)`Om$dlSirDE<3ODnwAXup)Qn92miJV?!G*5WX7+9A5YIvr)%tUawRQ-{^h%>} zVYyJVB^^fug4;MOi?(p@R`k~2ExHk`nVC5o)K!+rpSvxSfPS6hms#^TpS%R#r4M^s zJfXvd1lV4Ow*01fvW$Rh^eBXiQ(OjIBmLA_EYTWKa7B@T*Y-sNUV|!k?h>e+V19*K zB$6dDB>X7z$FwE@K~W&|kDEXLVUN=wla!5Qbev2dnu%WevAElGLm)7Yyk?ebpZXa% z9xlW+b(ZXC02|%B7PeOC29^B{=y9(jcYt>$Vg_>>Bm`)$t$#GYUB30^GASNxAI?;t z#@gX0E+T9_2PlY}G~C`YLK;JGA*fdu3BbRI=&NCC?eQDZ2DqD8k~Knvi`$#!cRq=F zcelqyV{p+2ni%fVr+!c62Y99ZH=f}YMCCVJOMcGI&76f4|GIrJC!o@^0I5y2mY;uU zOGN!GO9MWoSduSn3e8+a4wE53Bm=1~qf=($aNDL_|P2=@QVTDSFG5K%WF7fp_l5SXRr<1MK@6wmN-MKp+ zg2pJk`*Cjx*z_iXdt2m{%4zXfKWNkP?aL7_u?<;XK#*Iv=)g~JR~NtvmoI<$w`L-( z`i|^UFBwUx2Rc)3_d!9k7zdO8&4Y;5MKJ$`<9>kH7ra*_7&6qBBDDuo5BN7t!R(UY z8L1wUextS>N7gzxHsp~9k6!{2%BM3Q@&Gat5e-1OObknXR;3th8c;{+uZiYsp&fvr zKtzKen4X#aRGHkICJx?gQ6au8s;>M&UhO-OuqqLM#^n4MEm?uE!5r~+Jz(wBT=-vI z1g~o!8{4Kj0q|>7|MR1B9rCLAnL)Db`4@;GEwXwL0C<*!=o9jXjzd4^TVfRiHF#G_ zJDZ-=G3C`H1&^CdH~!W+*k7AydKolvV*h?Z(~lDDNV_MYFDhD@urN{2y@y89ey$!o z{j$QP+Hf3v95#1+JV-L(zE=D<*?g1PMY8KqoeBn9EYdPF9f#+XXmiex%?9}50?vI6 zQNGc8|0}!b>m&a62bTJ(8o}-h9fub_LyhO%s{$EDt&cwCp9!P(fTbnD{z z0Wpt40W2zs>Ws=NzwHz`e2~wv?q4GMbx0t=Et>#BA7}&sN87UksVzUvhG7J!0bsgi z56<~df?xrJ$k#Hv$-cH>Y7l{qJ~LP-Om{)2f)a#LMLRZc-LVqCI8>K1nEwt;ah{CT-84| zuHYg#-i$ZAX!%U@Rm;D@wGjpMU?gL`jSr_sj19-oTji4qVo)OFsIj$X!}50#oya+( z_HV(y#p@hCjNyj(^9+8}zlXv~q^#EXlC1vk!D9;^uk&~gs6p|AM#C@*`sRjQU$;3D z_5JN0gFXPn9Yy{N){7|(kt0FN=+~7BZ|~wqfdnGGiUfX1btGkipRo!JT8fE(m2x+gn5o<8!QC2GjkRW*n?+^;lN8>4>a zmL9v4&yGcbmRU^QHed9uB({mfoG6x#oQ-+U z0!Q1y`$yF~;QAuVpeRu&{`)Go0h_%;fT{fi5pedJ^p+6xYQD1>?bK#tnDjaxy2ZC&05oO za7sTSJ1%U-~JX$4|N`jru=1 z&LB36zrn9636Ffc=Nkt2&rfzr*&jE(VkOXBvELqH!yFZ)GgMxtvYEE2uzWK zV_EkeC=7z8qvF4VZl(b&7q*AWe=>NRyInzB#H^7_6HPvtGqC&-e!4Ygj~xFo=YVM&i49f=4Wzo@|v z`j^io#rF-CmA;S;?ZPU^Ps&Z^d+)v&hl_mqkleu*m(DyzjVlnt{kcaaXBdf zn6}iE$MDmZjQeY{3Eq$k{yHtN(~o>$MA$k*9eWS{{o`2<{Nx&-R@ zTAJ|^Ts7}-3-pCC@1<;~_>%ok)e3vwLW-#72kdvXRO>Z~n}1+jfHHY`#l6_RXrYj$ zCmEN9GD?$wv@rnyhC=@lPa_BX)yZ)jPskMB5Jw`cT=@;xCNmdr?evCI9@4-E8bn%^ z9=xEdk&!KJ7=yR}{FdBa>GW&4YW>b+9Ju8b1aGLIu?>r=_1|{~bDVYWjx96zppkkS z;bXFg@}iBmQC}JYm+H_Xl^Z4*dPmQ37l{Ss(KD?jQXo3&h&)mxSloQnbfKlm&w2K{ z-sD|KV^<~g?lTa1v0>arr5Mm7O({-2NC^I5aVJyXxpx8hVKwkAQGmn=eVCw)-8Sva z)u4*0^5paKzMPshCja~o60}|_ZJwR{kF$c(b{ehdJ9X&2+xW9V``>a3#evKhRIv(Y zCy6){oJXi=+m028t22LpRk%S5hoahUz}u^Ewlx)!ZpWPtv$H{?61>7n7hJd4_D$C30fb%)y)9MQiwOKr+9t%=D9UL9#+=K9`gFD^7%T%eG``s`!0WCDZ zyGyOwV7BoQ(wPKMdD@afV&JxB3PL2Z0`NB-;>~B{7y?E&L6brifp0tbT8RNa0URlc#rr83GWgWJy?&uHhCz02j@*t4 z7@Yk0EHiuSMX+1Q4Q~wcr2)W>pGna8lxH&?{b&Fex0Vh3nSW1~0=mx8_n19)44~~1 zgjB%4fy8472cT8p zo+p7ezIUK)tSJA`U;F5e9Tpk=-KT%6U$@o>w6fKt*+yt|lNzW`^iHYo!3 zu*%coNdHU-vQ*-y`%3_$x0~ZhrVztRO!eCtep@pDY0;h32B37gis^8_u2QcyJ)S8^ z>2s|DZk{?H?8a8`!Py%BH{fg)#KnVtRMmh}fe|kpHBgldWcgFW_j!QFZ>cr}>#MD} zOwxjxO!z6C!Ue0-_6m7hSkt z0BZ=Gpn(IN|LhDBSICPYDqq?8P>+ly;}Q=irQfPXJ5QTwsj$crc{>X^>)u@KcDO@) z0N$7C&*|D-Cim+$Qmhcsu&-6|xKzY0mufw6C!UN9)Sv4SHP_$pK6c>bpb*A96RF9f zX-KM@)0R|LpXC!ydPwd#w;ap0yH67jaPMpTjO!M9<%g1>Hva>7NESuH0EN@9AI@6o zbSq0X!X>q@aSwT`i3!j1GYlMD`wG24(^AB9VNC-&G1(rt{!It<_S?T!A4rY|uY>i0 zmKE`s&7Q}z`J4h;!^1Hv_EcL8JH0n7cEiE$@L&Qz(Dn?mU{c`=l(fImhLfq3%|1vM zFb7;ZpU6#iZg1B{w=qgYf22@m9RFU@-2_No5@gUJg_B&ax#>JbIW*BawZP{rM}e5Mq-)fYMF*ssnx(PG)y{y*&?pxvo{ zujF6$_d1fUR0!$E8&>NEN3I9v5qP?nENBkmbVS_2HzUjdx{1=e7P%f|SB;FPSQlwe zh-#&3NniF3%P^=1k_3W}1xgh7X3BA)_$z0&=slS(9ETr#UoIhco{GSz!M6_HrJwnt zM>FVmo#;D!7iNEDgRd|iPzI{Ul!TAK4S?r>?^yCn17Sj$=cZ^X9s2$wHCvwf8}6mP zLl7tTF#i>XG398bJ^o<&NXQb-z$gw`p`q`f62{itT$0P-m=SwjUDM5HY&Mx?+Cr4;NWo2gFHl6Q*kK`AR^?0_#SztE1?)Pu*Ut<@b?+=AA zx2;^A_7rT#4UXhPjr*G!R`9fO4($%kahjFs7vNQ->P|q8jhizJOdnYc)IpAk) zMr4n4KeaVL@@zuK1y{7YRit8@%V<{UuU4WDPLuhn))O@8w?oyBSBB5^+8xa}+!%(m z7{#8au`;B;6qR>!8oPXnYp6QO=HL#&;g+sr0QuGpzPw8K#14wq|JT39`hV{;vCvZ% Zm8;j&#LGp%r*{+-hPtMRkJ@%G{|_<)MHB!4 literal 0 HcmV?d00001 diff --git a/media/llama1-icon-transparent.svg b/media/llama1-icon-transparent.svg new file mode 100644 index 0000000000..e28203f4e8 --- /dev/null +++ b/media/llama1-icon-transparent.svg @@ -0,0 +1,77 @@ + + + + + + + + + + + + + + + + + From e7b6d83b524bbc24a1343d862de6dba8e8eddbd6 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Wed, 10 Sep 2025 14:17:09 +0200 Subject: [PATCH 40/46] tests : filter out no-ops from coverage report (#15900) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * tests : filter out no-ops from coverage report This commit is a follow-up commit for #15745 to address the feedback on how no-op operations should be filtered out from the coverage report. The feedback regarding the UNARY and GLU sub-operations not being handled I not exactly sure what should be done. They are included in the coverage, for example ABS, ELU, EXP, GELU, GEGLU, GEGLU_ERF etc are in the list of covered operations: ```console $ ./build/bin/test-backend-ops --show-coverage Operations covered by tests (89): ✓ ABS ✓ ACC ✓ ADD ✓ ADD1 ✓ ADD_ID ✓ ARANGE ✓ ARGMAX ✓ ARGSORT ✓ CLAMP ✓ CONCAT ✓ CONV_2D ✓ CONV_2D_DW ✓ CONV_3D ✓ CONV_TRANSPOSE_1D ✓ CONV_TRANSPOSE_2D ✓ COS ✓ COUNT_EQUAL ✓ CPY ✓ CROSS_ENTROPY_LOSS ✓ CROSS_ENTROPY_LOSS_BACK ✓ DIAG_MASK_INF ✓ DIV ✓ DUP ✓ ELU ✓ EXP ✓ FLASH_ATTN_EXT ✓ GATED_LINEAR_ATTN ✓ GEGLU ✓ GEGLU_ERF ✓ GEGLU_QUICK ✓ GELU ✓ GELU_ERF ✓ GELU_QUICK ✓ GET_ROWS ✓ GET_ROWS_BACK ✓ GROUP_NORM ✓ HARDSIGMOID ✓ HARDSWISH ✓ IM2COL ✓ IM2COL_3D ✓ L2_NORM ✓ LEAKY_RELU ✓ LOG ✓ MEAN ✓ MUL ✓ MUL_MAT ✓ MUL_MAT_ID ✓ NEG ✓ NORM ✓ OPT_STEP_ADAMW ✓ OPT_STEP_SGD ✓ OUT_PROD ✓ PAD ✓ PAD_REFLECT_1D ✓ POOL_2D ✓ REGLU ✓ RELU ✓ REPEAT ✓ REPEAT_BACK ✓ RMS_NORM ✓ RMS_NORM_BACK ✓ ROLL ✓ ROPE ✓ ROPE_BACK ✓ RWKV_WKV6 ✓ RWKV_WKV7 ✓ SCALE ✓ SET ✓ SET_ROWS ✓ SGN ✓ SIGMOID ✓ SILU ✓ SILU_BACK ✓ SIN ✓ SOFT_MAX ✓ SOFT_MAX_BACK ✓ SQR ✓ SQRT ✓ SSM_CONV ✓ SSM_SCAN ✓ STEP ✓ SUB ✓ SUM ✓ SUM_ROWS ✓ SWIGLU ✓ SWIGLU_OAI ✓ TANH ✓ TIMESTEP_EMBEDDING ✓ UPSCALE Operations without tests (14): ✗ ADD_REL_POS ✗ CUSTOM ✗ DIAG ✗ DIAG_MASK_ZERO ✗ FLASH_ATTN_BACK ✗ GET_REL_POS ✗ IM2COL_BACK ✗ MAP_CUSTOM1 ✗ MAP_CUSTOM2 ✗ MAP_CUSTOM3 ✗ POOL_1D ✗ POOL_2D_BACK ✗ WIN_PART ✗ WIN_UNPART Coverage Summary: Total operations: 103 Tested operations: 89 Untested operations: 14 Coverage: 86.4% ``` Refs: https://github.com/ggml-org/llama.cpp/pull/15745 * use of ggml_op enum values instead of strcmp --- tests/test-backend-ops.cpp | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index adf91ab6f9..4a882ab072 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6807,7 +6807,17 @@ static void list_all_ops() { static void show_test_coverage() { std::set all_ops; for (int i = 1; i < GGML_OP_COUNT; i++) { - all_ops.insert(ggml_op_name((enum ggml_op)i)); + auto op = (enum ggml_op)i; + if (op == GGML_OP_VIEW || + op == GGML_OP_RESHAPE || + op == GGML_OP_PERMUTE || + op == GGML_OP_TRANSPOSE || + op == GGML_OP_CONT || + op == GGML_OP_GLU || + op == GGML_OP_UNARY) { + continue; + } + all_ops.insert(ggml_op_name(op)); } for (int i = 0; i < GGML_UNARY_OP_COUNT; i++) { all_ops.insert(ggml_unary_op_name((enum ggml_unary_op)i)); From 33daece86b65607451d0d4378d2d04ba6a20ad55 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Wed, 10 Sep 2025 15:39:57 +0200 Subject: [PATCH 41/46] ci : add caching for ROCm installation in release workflow (#15924) This commit applies the same caching to the release workflow which currently exists for the main CI workflow that was introduced in Commit ff02caf9eed261423289d1531a56536fbf57bfc2 ("ci : cache ROCm installation in windows-latest-cmake-hip (#15887)"). --- .github/workflows/release.yml | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 5367637e42..701811eeb2 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -544,13 +544,23 @@ jobs: run: | git clone https://github.com/rocm/rocwmma --branch rocm-6.2.4 --depth 1 + - name: Cache ROCm Installation + id: cache-rocm + uses: actions/cache@v4 + with: + path: C:\Program Files\AMD\ROCm + key: rocm-6.1-${{ runner.os }}-v1 + restore-keys: | + rocm-6.1-${{ runner.os }}- + - name: ccache uses: ggml-org/ccache-action@v1.2.16 with: key: windows-latest-cmake-hip-${{ matrix.name }}-x64 evict-old-files: 1d - - name: Install + - name: Install ROCm + if: steps.cache-rocm.outputs.cache-hit != 'true' id: depends run: | $ErrorActionPreference = "Stop" @@ -558,13 +568,28 @@ jobs: Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe" write-host "Installing AMD HIP SDK" $proc = Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -PassThru - $proc.WaitForExit(600000) + $completed = $proc.WaitForExit(600000) + if (-not $completed) { + Write-Error "ROCm installation timed out after 10 minutes. Killing the process" + $proc.Kill() + exit 1 + } + if ($proc.ExitCode -ne 0) { + Write-Error "ROCm installation failed with exit code $($proc.ExitCode)" + exit 1 + } write-host "Completed AMD HIP SDK installation" - name: Verify ROCm id: verify run: | - & 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version + # Find and test ROCm installation + $clangPath = Get-ChildItem 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | Select-Object -First 1 + if (-not $clangPath) { + Write-Error "ROCm installation not found" + exit 1 + } + & $clangPath.FullName --version - name: Build id: cmake_build From 0f0a3c2851134d49955f3c85afbb0b1bb47c3e07 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 10 Sep 2025 17:52:35 +0300 Subject: [PATCH 42/46] metal : make the backend async (#15906) * metal : make the backend async ggml-ci * cont : add comments, extend op offload, clean up ggml-ci * metal : fix batch size for MUL_MAT_ID * metal : remove deprecated ggml_backend_metal_buffer_from_ptr * metal : create only metal buffers, no wrapping of host memory ggml-ci * metal : restore .alloc_buffer for buffer_from_ptr_type ggml-ci * metal : remove broken implementation of GGML_OP_SET ggml-ci * metal : clean-up loose ends, ready for tests ggml-ci * metal : support both private and shared buffers ggml-ci * metal : enable private buffers + add global device queue * metal : disable host buffer to prevent races ggml-ci * metal : avoid extra copy during set_tensor ggml-ci * metal : use separate buffer types for shread and private Metal buffers ggml-ci * metal : simplify synchronization logic ggml-ci * metal : fix build ggml-ci * metal : do not implement cpy_tensor ggml-ci * metal : separate implementations for shared and private buffers ggml-ci --- ggml/include/ggml-metal.h | 6 - ggml/src/ggml-metal/ggml-metal.m | 975 ++++++++++++++++++--------- ggml/src/ggml-metal/ggml-metal.metal | 32 - 3 files changed, 674 insertions(+), 339 deletions(-) diff --git a/ggml/include/ggml-metal.h b/ggml/include/ggml-metal.h index a610694423..1163438bc2 100644 --- a/ggml/include/ggml-metal.h +++ b/ggml/include/ggml-metal.h @@ -43,14 +43,8 @@ GGML_BACKEND_API ggml_backend_t ggml_backend_metal_init(void); GGML_BACKEND_API bool ggml_backend_is_metal(ggml_backend_t backend); -GGML_DEPRECATED( - GGML_BACKEND_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size), - "obsoleted by the new device interface - https://github.com/ggml-org/llama.cpp/pull/9713"); - GGML_BACKEND_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data); -GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void); - // helper to check if the device supports a specific family // ideally, the user code should be doing these checks // ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index e76fb71263..07b96dbdd5 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -48,6 +48,11 @@ static struct ggml_backend_metal_device_context { int mtl_device_ref_count; id mtl_library; + // a single global queue shared by all Metal backends + // technically not needed for devices with unified memory, but enables discrete GPUs support + // ref: https://github.com/ggml-org/llama.cpp/pull/15906 + id mtl_queue; + NSLock * mtl_lock; bool has_simdgroup_reduction; @@ -56,6 +61,7 @@ static struct ggml_backend_metal_device_context { bool has_bfloat; bool use_bfloat; bool use_fusion; + bool use_shared_buffers; int debug_fusion; @@ -69,6 +75,7 @@ static struct ggml_backend_metal_device_context { /*.mtl_device =*/ nil, /*.mtl_device_ref_count =*/ 0, /*.mtl_library =*/ nil, + /*.mtl_queue =*/ nil, /*.mtl_lock =*/ nil, /*.has_simdgroup_reduction =*/ false, /*.has_simdgroup_mm =*/ false, @@ -76,6 +83,7 @@ static struct ggml_backend_metal_device_context { /*.has_bfloat =*/ false, /*.use_bfloat =*/ false, /*.use_fusion =*/ true, + /*.use_shared_buffers =*/ true, /*.debug_fusion =*/ 0, /*.fuse_cnt =*/ { 0 }, /*.max_size =*/ 0, @@ -94,6 +102,11 @@ static id ggml_backend_metal_device_acq(struct ggml_backend_metal_dev ctx->mtl_device = MTLCreateSystemDefaultDevice(); if (ctx->mtl_device) { + ctx->mtl_queue = [ctx->mtl_device newCommandQueue]; + if (ctx->mtl_queue == nil) { + GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__); + } + ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7]; ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; @@ -118,6 +131,12 @@ static id ggml_backend_metal_device_acq(struct ggml_backend_metal_dev ctx->debug_fusion = val ? atoi(val) : 0; } + ctx->use_shared_buffers = ctx->mtl_device.hasUnifiedMemory; + + if (getenv("GGML_METAL_SHARED_BUFFERS_DISABLE") != NULL) { + ctx->use_shared_buffers = false; + } + memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt)); ctx->max_size = ctx->mtl_device.maxBufferLength; @@ -161,6 +180,11 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte ctx->mtl_library = nil; } + if (ctx->mtl_queue) { + [ctx->mtl_queue release]; + ctx->mtl_queue = nil; + } + if (ctx->mtl_device) { [ctx->mtl_device release]; ctx->mtl_device = nil; @@ -467,8 +491,6 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, - GGML_METAL_KERNEL_TYPE_SET_I32, - GGML_METAL_KERNEL_TYPE_SET_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, @@ -773,7 +795,7 @@ struct ggml_metal_command_buffer { struct ggml_backend_metal_context { id device; - id queue; + id queue; // currently a pointer to the device queue, but might become separate queue [TAG_QUEUE_PER_BACKEND] dispatch_queue_t d_queue; @@ -803,6 +825,12 @@ struct ggml_backend_metal_context { // n_cb command buffers + 1 used by the main thread struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1]; + // extra command buffers for things like getting, setting and copying tensors + NSMutableArray * cmd_bufs_ext; + + // the last command buffer queued into the Metal queue with operations relevant to the current Metal backend + id cmd_buf_last; + // abort ggml_metal_graph_compute if callback returns true ggml_abort_callback abort_callback; void * abort_callback_data; @@ -999,7 +1027,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]); ctx->device = device; - ctx->queue = [device newCommandQueue]; + + // TODO: question - would it be better to have one queue for the backend and one queue for the device? + // the graph encoders and async ops would use the backend queue while the sync ops would use the device queue? + //ctx->queue = [device newCommandQueue]; [TAG_QUEUE_PER_BACKEND] + ctx->queue = ctx_dev->mtl_queue; if (ctx->queue == nil) { GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__); return NULL; @@ -1058,6 +1090,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_LOG_INFO("%s: has residency sets = %s\n", __func__, ctx_dev->has_residency_sets ? "true" : "false"); GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false"); GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false"); + GGML_LOG_INFO("%s: use fusion = %s\n", __func__, ctx_dev->use_fusion ? "true" : "false"); + GGML_LOG_INFO("%s: use shared buffers = %s\n", __func__, ctx_dev->use_shared_buffers ? "true" : "false"); GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false"); ctx->capture_next_compute = false; @@ -1073,6 +1107,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de ctx->cmd_bufs[i].mem_pool->device = device; } + ctx->cmd_bufs_ext = [[NSMutableArray alloc] init]; + + ctx->cmd_buf_last = nil; + #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) if (@available(macOS 10.12, iOS 16.0, *)) { GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, device.recommendedMaxWorkingSetSize / 1e6); @@ -1390,8 +1428,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat); @@ -1663,14 +1699,19 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) { Block_release(ctx->encode_async); - [ctx->queue release]; + //[ctx->queue release]; // [TAG_QUEUE_PER_BACKEND] for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) { - // ctx->cmd_bufs[i].obj is auto released + if (ctx->cmd_bufs[i].obj) { + [ctx->cmd_bufs[i].obj release]; + } ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool); } + [ctx->cmd_bufs_ext removeAllObjects]; + [ctx->cmd_bufs_ext release]; + dispatch_release(ctx->d_queue); free(ctx); @@ -1688,14 +1729,21 @@ struct ggml_backend_metal_buffer { struct ggml_backend_metal_buffer_context { void * all_data; size_t all_size; - bool owned; + + // if false, the Metal buffer data is allocated in private GPU memory and is not shared with the host + bool is_shared; // multiple buffers are used only to avoid the maximum buffer size limitation when using mmap int n_buffers; struct ggml_backend_metal_buffer buffers[GGML_METAL_MAX_BUFFERS]; // optional MTLResidencySet + // note: cannot use explicity "id" here because it is not available on certain OSes id rset; + + // pointers to global device objects + id device; + id queue; }; // rset init @@ -1761,7 +1809,7 @@ static void ggml_backend_metal_buffer_rset_free(struct ggml_backend_metal_buffer // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the // Metal buffer based on the host memory pointer // -static id ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs) { +static id ggml_metal_get_buffer(const struct ggml_tensor * t, size_t * offs) { //GGML_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach); const int64_t tsize = ggml_nbytes(t); @@ -1984,16 +2032,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex return false; }; } - case GGML_OP_SET: - { - switch (op->src[0]->type) { - case GGML_TYPE_F32: - case GGML_TYPE_I32: - return true; - default: - return false; - }; - } case GGML_OP_DIAG_MASK_INF: case GGML_OP_GET_ROWS: { @@ -5569,68 +5607,6 @@ static int ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)]; } break; - case GGML_OP_SET: - { - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); - - // src0 and dst as viewed during set - const size_t dst_nb0 = ggml_element_size(src0); - - const size_t dst_nb1 = ((int32_t *) dst->op_params)[0]; - const size_t dst_nb2 = ((int32_t *) dst->op_params)[1]; - const size_t dst_nb3 = ((int32_t *) dst->op_params)[2]; - const size_t offset = ((int32_t *) dst->op_params)[3]; - const bool inplace = (bool) ((int32_t *) dst->op_params)[4]; - - if (!inplace) { - memcpy(((char *) dst->data), ((char *) src0->data), ggml_nbytes(dst)); - } - - const int im0 = (ne10 == 0 ? 0 : ne10-1); - const int im1 = (ne11 == 0 ? 0 : ne11-1); - const int im2 = (ne12 == 0 ? 0 : ne12-1); - const int im3 = (ne13 == 0 ? 0 : ne13-1); - - GGML_ASSERT(offset + im0*dst_nb0 + im1*dst_nb1 + im2*dst_nb2 + im3*dst_nb3 <= ggml_nbytes(dst)); - - id pipeline = nil; - - switch (src0t) { - case GGML_TYPE_F32: - GGML_ASSERT(nb10 == sizeof(float)); - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_F32].pipeline; break; - case GGML_TYPE_I32: - GGML_ASSERT(nb10 == sizeof(int32_t)); - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_I32].pipeline; break; - default: GGML_ABORT("fatal error"); - } - - ggml_metal_kargs_set args = { - /*.ne10 =*/ ne10, - /*.ne11 =*/ ne11, - /*.ne12 =*/ ne12, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.nb1 =*/ dst_nb1, - /*.nb2 =*/ dst_nb2, - /*.nb3 =*/ dst_nb3, - /*.offs =*/ offset, - /*.inplace =*/ inplace, - }; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne10); - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; case GGML_OP_POOL_2D: { GGML_ASSERT(ggml_is_contiguous(src0)); @@ -5759,6 +5735,12 @@ static enum ggml_status ggml_metal_graph_compute( if (should_capture) { ctx->capture_next_compute = false; + // make sure all previous computations have finished before starting the capture + if (ctx->cmd_buf_last) { + [ctx->cmd_buf_last waitUntilCompleted]; + ctx->cmd_buf_last = nil; + } + if (!ctx->capture_started) { // create capture scope ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx_dev->mtl_device]; @@ -5781,78 +5763,103 @@ static enum ggml_status ggml_metal_graph_compute( // the main thread commits the first few commands immediately // cmd_buf[n_cb] { - id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; + // cannot use commandBufferWithUnretainedReferences because the buffers from the memory pool can get destroyed + // TODO: when the memory pools are removed, we can again use commandBufferWithUnretainedReferences + // https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2334215009 + //id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; + id cmd_buf = [ctx->queue commandBuffer]; + [cmd_buf retain]; + ctx->cmd_bufs[n_cb].obj = cmd_buf; [cmd_buf enqueue]; + ctx->encode_async(n_cb); } - // prepare the rest of the command buffers asynchronously + // remember the command buffer for the next iteration + ctx->cmd_buf_last = ctx->cmd_bufs[n_cb].obj; + + // prepare the rest of the command buffers asynchronously (optional) // cmd_buf[0.. n_cb) for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { - id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; + //id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; + id cmd_buf = [ctx->queue commandBuffer]; + [cmd_buf retain]; + + if (ctx->cmd_bufs[cb_idx].obj) { + [ctx->cmd_bufs[cb_idx].obj release]; + } ctx->cmd_bufs[cb_idx].obj = cmd_buf; // always enqueue the first two command buffers // enqueue all of the command buffers if we don't need to abort if (cb_idx < 2 || ctx->abort_callback == NULL) { [cmd_buf enqueue]; + + // update the pointer to the last queued command buffer + // this is needed to implement synchronize() + ctx->cmd_buf_last = cmd_buf; } } dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async); - // wait for completion and check status of each command buffer - // needed to detect if the device ran out-of-memory for example (#1881) - { - id cmd_buf = ctx->cmd_bufs[n_cb].obj; - [cmd_buf waitUntilCompleted]; - - MTLCommandBufferStatus status = [cmd_buf status]; - if (status != MTLCommandBufferStatusCompleted) { - GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status); - if (status == MTLCommandBufferStatusError) { - GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); - } - - return GGML_STATUS_FAILED; - } - } - - for (int i = 0; i < n_cb; ++i) { - id cmd_buf = ctx->cmd_bufs[i].obj; - [cmd_buf waitUntilCompleted]; - - MTLCommandBufferStatus status = [cmd_buf status]; - if (status != MTLCommandBufferStatusCompleted) { - GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); - if (status == MTLCommandBufferStatusError) { - GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); - } - - return GGML_STATUS_FAILED; - } - - id next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil); - if (!next_buffer) { - continue; - } - - const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued); - if (next_queued) { - continue; - } - - if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) { - GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i); - return GGML_STATUS_ABORTED; - } - - [next_buffer commit]; - } + // for debugging: block until graph is computed + //[ctx->cmd_buf_last waitUntilCompleted]; + // enter here only when capturing in order to wait for all computation to finish + // otherwise, we leave the graph to compute asynchronously if (!should_capture && ctx->capture_started) { + // wait for completion and check status of each command buffer + // needed to detect if the device ran out-of-memory for example (#1881) + { + id cmd_buf = ctx->cmd_bufs[n_cb].obj; + [cmd_buf waitUntilCompleted]; + + MTLCommandBufferStatus status = [cmd_buf status]; + if (status != MTLCommandBufferStatusCompleted) { + GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status); + if (status == MTLCommandBufferStatusError) { + GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); + } + + return GGML_STATUS_FAILED; + } + } + + for (int i = 0; i < n_cb; ++i) { + id cmd_buf = ctx->cmd_bufs[i].obj; + [cmd_buf waitUntilCompleted]; + + MTLCommandBufferStatus status = [cmd_buf status]; + if (status != MTLCommandBufferStatusCompleted) { + GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); + if (status == MTLCommandBufferStatusError) { + GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); + } + + return GGML_STATUS_FAILED; + } + + id next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil); + if (!next_buffer) { + continue; + } + + const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued); + if (next_queued) { + continue; + } + + if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) { + GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i); + return GGML_STATUS_ABORTED; + } + + [next_buffer commit]; + } + [ctx->capture_scope endScope]; [[MTLCaptureManager sharedCaptureManager] stopCapture]; } @@ -5862,10 +5869,12 @@ static enum ggml_status ggml_metal_graph_compute( } //////////////////////////////////////////////////////////////////////////////// - // backend interface +//////////////////////////////////////////////////////////////////////////////// -static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) { +// shared buffer + +static void ggml_backend_metal_buffer_shared_free_buffer(ggml_backend_buffer_t buffer) { struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; for (int i = 0; i < ctx->n_buffers; i++) { @@ -5874,7 +5883,9 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) ggml_backend_metal_buffer_rset_free(ctx); - if (ctx->owned) { + GGML_ASSERT(ctx->is_shared); + + { #if TARGET_OS_OSX vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ctx->all_data, ctx->all_size); #else @@ -5885,66 +5896,254 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) free(ctx); } -static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) { +static void * ggml_backend_metal_buffer_shared_get_base(ggml_backend_buffer_t buffer) { struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; return ctx->all_data; } -static void ggml_backend_metal_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { - memset((char *)tensor->data + offset, value, size); - - GGML_UNUSED(buffer); -} - -static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - memcpy((char *)tensor->data + offset, data, size); - - GGML_UNUSED(buffer); -} - -static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { - memcpy(data, (const char *)tensor->data + offset, size); - - GGML_UNUSED(buffer); -} - -static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) { - if (ggml_backend_buffer_is_host(src->buffer)) { - memcpy(dst->data, src->data, ggml_nbytes(src)); - return true; - } - return false; - - GGML_UNUSED(buffer); -} - -static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { +static void ggml_backend_metal_buffer_shared_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; + GGML_ASSERT(ctx->is_shared); + + memset((char *)tensor->data + offset, value, size); +} + +static void ggml_backend_metal_buffer_shared_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; + + GGML_ASSERT(ctx->is_shared); + + memcpy((char *)tensor->data + offset, data, size); +} + +static void ggml_backend_metal_buffer_shared_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; + + GGML_ASSERT(ctx->is_shared); + + memcpy(data, (const char *)tensor->data + offset, size); +} + +static bool ggml_backend_metal_buffer_shared_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) { + GGML_UNUSED(buffer); + GGML_UNUSED(src); + GGML_UNUSED(dst); + + return false; +} + +static void ggml_backend_metal_buffer_shared_clear(ggml_backend_buffer_t buffer, uint8_t value) { + struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; + + GGML_ASSERT(ctx->is_shared); + memset(ctx->all_data, value, ctx->all_size); } -static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = { - /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer, - /* .get_base = */ ggml_backend_metal_buffer_get_base, +static struct ggml_backend_buffer_i ggml_backend_metal_buffer_shared_i = { + /* .free_buffer = */ ggml_backend_metal_buffer_shared_free_buffer, + /* .get_base = */ ggml_backend_metal_buffer_shared_get_base, /* .init_tensor = */ NULL, - /* .memset_tensor = */ ggml_backend_metal_buffer_memset_tensor, - /* .set_tensor = */ ggml_backend_metal_buffer_set_tensor, - /* .get_tensor = */ ggml_backend_metal_buffer_get_tensor, - /* .cpy_tensor = */ ggml_backend_metal_buffer_cpy_tensor, - /* .clear = */ ggml_backend_metal_buffer_clear, + /* .memset_tensor = */ ggml_backend_metal_buffer_shared_memset_tensor, + /* .set_tensor = */ ggml_backend_metal_buffer_shared_set_tensor, + /* .get_tensor = */ ggml_backend_metal_buffer_shared_get_tensor, + /* .cpy_tensor = */ ggml_backend_metal_buffer_shared_cpy_tensor, + /* .clear = */ ggml_backend_metal_buffer_shared_clear, /* .reset = */ NULL, }; -// default buffer type +// private buffer -static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) { - return "Metal"; +static void ggml_backend_metal_buffer_private_free_buffer(ggml_backend_buffer_t buffer) { + struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; - GGML_UNUSED(buft); + for (int i = 0; i < ctx->n_buffers; i++) { + [ctx->buffers[i].metal release]; + } + + ggml_backend_metal_buffer_rset_free(ctx); + + GGML_ASSERT(!ctx->is_shared); + + free(ctx); } +static void * ggml_backend_metal_buffer_private_get_base(ggml_backend_buffer_t buffer) { + struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; + + return ctx->all_data; +} + +static void ggml_backend_metal_buffer_private_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; + + GGML_ASSERT(!ctx->is_shared); + + @autoreleasepool { + // dst + size_t buf_dst_offset = 0; + id buf_dst = ggml_metal_get_buffer(tensor, &buf_dst_offset); + + buf_dst_offset += offset; + + id queue = ctx->queue; + id cmd_buf = [queue commandBufferWithUnretainedReferences]; + + { + id encoder = [cmd_buf blitCommandEncoder]; + + [encoder fillBuffer:buf_dst + range:NSMakeRange(buf_dst_offset, buf_dst_offset + size) + value:value]; + + [encoder endEncoding]; + } + + [cmd_buf commit]; + [cmd_buf waitUntilCompleted]; + } +} + +static void ggml_backend_metal_buffer_private_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; + + GGML_ASSERT(!ctx->is_shared); + + @autoreleasepool { + // src + void * data_ptr = (void *)(uintptr_t) data; // "const cast" the src data + id buf_src = [ctx->device newBufferWithBytesNoCopy:data_ptr + length:size + options:MTLResourceStorageModeShared + deallocator:nil]; + + // dst + size_t buf_dst_offset = 0; + id buf_dst = ggml_metal_get_buffer(tensor, &buf_dst_offset); + + buf_dst_offset += offset; + + // note: for experimentation purposes, here we use a semaphore to wait for the copy to complete + // this is alternative to waitUntilCompleted, which should be faster, but don't seem to make much difference + dispatch_semaphore_t completion_semaphore = dispatch_semaphore_create(0); + + id queue = ctx->queue; + id cmd_buf = [queue commandBufferWithUnretainedReferences]; + + { + id encoder = [cmd_buf blitCommandEncoder]; + + [encoder copyFromBuffer:buf_src + sourceOffset:0 + toBuffer:buf_dst + destinationOffset:buf_dst_offset + size:size]; + + [encoder endEncoding]; + } + + [cmd_buf addCompletedHandler:^(id cb) { + // TODO: can check for errors here + GGML_UNUSED(cb); + + dispatch_semaphore_signal(completion_semaphore); + }]; + + [cmd_buf commit]; + + dispatch_semaphore_wait(completion_semaphore, DISPATCH_TIME_FOREVER); + //[cmd_buf waitUntilCompleted]; + } +} + +static void ggml_backend_metal_buffer_private_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; + + GGML_ASSERT(!ctx->is_shared); + + @autoreleasepool { + // src + size_t buf_src_offset = 0; + id buf_src = ggml_metal_get_buffer(tensor, &buf_src_offset); + + buf_src_offset += offset; + + // dst + id buf_dst = [ctx->device newBufferWithBytesNoCopy:data + length:size + options:MTLResourceStorageModeShared + deallocator:nil]; + + id queue = ctx->queue; + id cmd_buf = [queue commandBufferWithUnretainedReferences]; + + { + id encoder = [cmd_buf blitCommandEncoder]; + + [encoder copyFromBuffer:buf_src + sourceOffset:buf_src_offset + toBuffer:buf_dst + destinationOffset:0 + size:size]; + + [encoder endEncoding]; + } + + [cmd_buf commit]; + [cmd_buf waitUntilCompleted]; + } +} + +static bool ggml_backend_metal_buffer_private_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) { + GGML_UNUSED(buffer); + GGML_UNUSED(src); + GGML_UNUSED(dst); + + return false; +} + +static void ggml_backend_metal_buffer_private_clear(ggml_backend_buffer_t buffer, uint8_t value) { + struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; + + GGML_ASSERT(!ctx->is_shared); + + @autoreleasepool { + id queue = ctx->queue; + id cmd_buf = [queue commandBufferWithUnretainedReferences]; + + { + id encoder = [cmd_buf blitCommandEncoder]; + + [encoder fillBuffer:ctx->buffers[0].metal + range:NSMakeRange(0, ctx->buffers[0].size) + value:value]; + + [encoder endEncoding]; + } + + [cmd_buf commit]; + [cmd_buf waitUntilCompleted]; + } +} + +static struct ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = { + /* .free_buffer = */ ggml_backend_metal_buffer_private_free_buffer, + /* .get_base = */ ggml_backend_metal_buffer_private_get_base, + /* .init_tensor = */ NULL, + /* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor, + /* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor, + /* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor, + /* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor, + /* .clear = */ ggml_backend_metal_buffer_private_clear, + /* .reset = */ NULL, +}; + +// +// buffer types +// + static void ggml_backend_metal_log_allocated_size(id device, size_t size_aligned) { #ifndef GGML_METAL_NDEBUG #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) @@ -5970,7 +6169,8 @@ static void ggml_backend_metal_log_allocated_size(id device, size_t s GGML_UNUSED(size_aligned); } -static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { +// common method for allocating shread or private Metal buffers +static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size, bool shared) { struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context)); const size_t size_page = sysconf(_SC_PAGESIZE); @@ -5986,22 +6186,40 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba id device = ctx_dev->mtl_device; - ctx->all_data = ggml_metal_host_malloc(size_aligned); + // allocate shared buffer if the device supports it and it is required by the buffer type + if (ctx_dev->use_shared_buffers && shared) { + ctx->all_data = ggml_metal_host_malloc(size_aligned); + ctx->is_shared = true; + } else { + // dummy, non-NULL value - we'll populate this after creating the Metal buffer below + ctx->all_data = (void *) 0x000000400ULL; + ctx->is_shared = false; + } ctx->all_size = size_aligned; - ctx->owned = true; + + ctx->device = device; + ctx->queue = ctx_dev->mtl_queue; + ctx->n_buffers = 1; if (ctx->all_data != NULL) { - ctx->buffers[0].data = ctx->all_data; ctx->buffers[0].size = size; ctx->buffers[0].metal = nil; if (size_aligned > 0) { - ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data - length:size_aligned - options:MTLResourceStorageModeShared - deallocator:nil]; + if (ctx_dev->use_shared_buffers) { + ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data + length:size_aligned + options:MTLResourceStorageModeShared + deallocator:nil]; + } else { + ctx->buffers[0].metal = [device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate]; + + ctx->all_data = (void *) (ctx->buffers[0].metal.gpuAddress); + } } + + ctx->buffers[0].data = ctx->all_data; } if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) { @@ -6018,36 +6236,50 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba //ggml_backend_metal_log_allocated_size(device, size_aligned); - return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size); + struct ggml_backend_buffer_i buf_i = ctx->is_shared ? ggml_backend_metal_buffer_shared_i : ggml_backend_metal_buffer_private_i; + + return ggml_backend_buffer_init(buft, buf_i, ctx, size); } -static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { +// default (shared) buffer type + +static const char * ggml_backend_metal_buffer_type_shared_get_name(ggml_backend_buffer_type_t buft) { + return "Metal"; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_t ggml_backend_metal_buffer_type_shared_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, true); +} + +static size_t ggml_backend_metal_buffer_type_shared_get_alignment(ggml_backend_buffer_type_t buft) { return 32; GGML_UNUSED(buft); } -static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { +static size_t ggml_backend_metal_buffer_type_shared_get_max_size(ggml_backend_buffer_type_t buft) { const size_t max_size = ((struct ggml_backend_metal_device_context *)buft->device->context)->max_size; return max_size; } -static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) { - return true; +static bool ggml_backend_metal_buffer_type_shared_is_host(ggml_backend_buffer_type_t buft) { + return false; GGML_UNUSED(buft); } -ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) { +static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(void) { static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = { /* .iface = */ { - /* .get_name = */ ggml_backend_metal_buffer_type_get_name, - /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment, - /* .get_max_size = */ ggml_backend_metal_buffer_type_get_max_size, + /* .get_name = */ ggml_backend_metal_buffer_type_shared_get_name, + /* .alloc_buffer = */ ggml_backend_metal_buffer_type_shared_alloc_buffer, + /* .get_alignment = */ ggml_backend_metal_buffer_type_shared_get_alignment, + /* .get_max_size = */ ggml_backend_metal_buffer_type_shared_get_max_size, /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes - /* .is_host = */ ggml_backend_metal_buffer_type_is_host, + /* .is_host = */ ggml_backend_metal_buffer_type_shared_is_host, }, /* .device = */ &g_ggml_backend_metal_device, /* .context = */ NULL, @@ -6056,116 +6288,101 @@ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) { return &ggml_backend_buffer_type_metal; } -static const char * ggml_backend_metal_buffer_from_ptr_type_get_name(ggml_backend_buffer_type_t buft) { - return "Metal_Mapped"; +// default (private) buffer type + +static const char * ggml_backend_metal_buffer_type_private_get_name(ggml_backend_buffer_type_t buft) { + return "Metal_Private"; GGML_UNUSED(buft); } -static ggml_backend_buffer_type_t ggml_backend_metal_buffer_from_ptr_type(void) { - static struct ggml_backend_buffer_type ggml_backend_buffer_from_ptr_type_metal = { +static ggml_backend_buffer_t ggml_backend_metal_buffer_type_private_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, false); +} + +static size_t ggml_backend_metal_buffer_type_private_get_alignment(ggml_backend_buffer_type_t buft) { + return 32; + + GGML_UNUSED(buft); +} + +static size_t ggml_backend_metal_buffer_type_private_get_max_size(ggml_backend_buffer_type_t buft) { + const size_t max_size = ((struct ggml_backend_metal_device_context *)buft->device->context)->max_size; + + return max_size; +} + +static bool ggml_backend_metal_buffer_type_private_is_host(ggml_backend_buffer_type_t buft) { + return false; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_private(void) { + static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = { /* .iface = */ { - /* .get_name = */ ggml_backend_metal_buffer_from_ptr_type_get_name, - /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment, - /* .get_max_size = */ ggml_backend_metal_buffer_type_get_max_size, + /* .get_name = */ ggml_backend_metal_buffer_type_private_get_name, + /* .alloc_buffer = */ ggml_backend_metal_buffer_type_private_alloc_buffer, + /* .get_alignment = */ ggml_backend_metal_buffer_type_private_get_alignment, + /* .get_max_size = */ ggml_backend_metal_buffer_type_private_get_max_size, /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes - /* .is_host = */ ggml_backend_metal_buffer_type_is_host, + /* .is_host = */ ggml_backend_metal_buffer_type_private_is_host, }, /* .device = */ &g_ggml_backend_metal_device, /* .context = */ NULL, }; - return &ggml_backend_buffer_from_ptr_type_metal; + return &ggml_backend_buffer_type_metal; } -// TODO: obsoleted by ggml_backend_metal_device_buffer_from_ptr -ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) { - struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context)); +// mapped buffer type - ctx->all_data = data; - ctx->all_size = size; - ctx->owned = false; - ctx->n_buffers = 0; +static const char * ggml_backend_metal_buffer_type_mapped_get_name(ggml_backend_buffer_type_t buft) { + return "Metal_Mapped"; - const size_t size_page = sysconf(_SC_PAGESIZE); + GGML_UNUSED(buft); +} - // page-align the data ptr - { - const uintptr_t offs = (uintptr_t) data % size_page; - data = (void *) ((char *) data - offs); - size += offs; - } +static ggml_backend_buffer_t ggml_backend_metal_buffer_type_mapped_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + // for mapped buffers, prefer shared memory + return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, true); +} - size_t size_aligned = size; - if ((size_aligned % size_page) != 0) { - size_aligned += (size_page - (size_aligned % size_page)); - } +static size_t ggml_backend_metal_buffer_type_mapped_get_alignment(ggml_backend_buffer_type_t buft) { + return 32; - struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main; + GGML_UNUSED(buft); +} - GGML_ASSERT(ctx_dev->mtl_device != nil); +static size_t ggml_backend_metal_buffer_type_mapped_get_max_size(ggml_backend_buffer_type_t buft) { + const size_t max_size = ((struct ggml_backend_metal_device_context *)buft->device->context)->max_size; - id device = ctx_dev->mtl_device; + return max_size; +} - // the buffer fits into the max buffer size allowed by the device - if (size_aligned <= device.maxBufferLength) { - ctx->buffers[ctx->n_buffers].data = data; - ctx->buffers[ctx->n_buffers].size = size; - ctx->buffers[ctx->n_buffers].metal = nil; +static bool ggml_backend_metal_buffer_type_mapped_is_host(ggml_backend_buffer_type_t buft) { + return false; - if (size_aligned > 0) { - ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; + GGML_UNUSED(buft); +} - if (ctx->buffers[ctx->n_buffers].metal == nil) { - GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); - return false; - } - } +static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_mapped(void) { + // note: not obvious, but this buffer type still needs to implement .alloc_buffer: + // https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2333177099 + static struct ggml_backend_buffer_type ggml_backend_buffer_type_mapped_metal = { + /* .iface = */ { + /* .get_name = */ ggml_backend_metal_buffer_type_mapped_get_name, + /* .alloc_buffer = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer, + /* .get_alignment = */ ggml_backend_metal_buffer_type_mapped_get_alignment, + /* .get_max_size = */ ggml_backend_metal_buffer_type_mapped_get_max_size, + /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes + /* .is_host = */ ggml_backend_metal_buffer_type_mapped_is_host, + }, + /* .device = */ &g_ggml_backend_metal_device, + /* .context = */ NULL, + }; - ggml_backend_metal_log_allocated_size(device, size_aligned); - - ++ctx->n_buffers; - } else { - // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into - // one of the views - const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case - const size_t size_step = device.maxBufferLength - size_ovlp; - const size_t size_view = device.maxBufferLength; - - for (size_t i = 0; i < size; i += size_step) { - const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i); - - ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i); - ctx->buffers[ctx->n_buffers].size = size_step_aligned; - ctx->buffers[ctx->n_buffers].metal = nil; - - if (size_step_aligned > 0) { - ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil]; - - if (ctx->buffers[ctx->n_buffers].metal == nil) { - GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0); - return false; - } - } - - ggml_backend_metal_log_allocated_size(device, size_step_aligned); - - if (i + size_step < size) { - GGML_LOG_INFO("\n"); - } - - ++ctx->n_buffers; - } - } - - if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) { - GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); - free(ctx); - return NULL; - } - - return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size); + return &ggml_backend_buffer_type_mapped_metal; } // backend @@ -6184,6 +6401,137 @@ static void ggml_backend_metal_free(ggml_backend_t backend) { free(backend); } +static void ggml_backend_metal_synchronize(ggml_backend_t backend) { + struct ggml_backend_metal_context * ctx = backend->context; + + // wait for any backend operations to finish + if (ctx->cmd_buf_last) { + [ctx->cmd_buf_last waitUntilCompleted]; + ctx->cmd_buf_last = nil; + } + + // release any completed command buffers + if (ctx->cmd_bufs_ext.count > 0) { + for (size_t i = 0; i < ctx->cmd_bufs_ext.count; ++i) { + id cmd_buf = ctx->cmd_bufs_ext[i]; + + MTLCommandBufferStatus status = [cmd_buf status]; + if (status != MTLCommandBufferStatusCompleted) { + GGML_LOG_ERROR("%s: error: command buffer %d failed with status %d\n", __func__, (int) i, (int) status); + if (status == MTLCommandBufferStatusError) { + GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); + } + GGML_ABORT("fatal error"); + } + + [cmd_buf release]; + } + + [ctx->cmd_bufs_ext removeAllObjects]; + } +} + +static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + struct ggml_backend_metal_context * ctx = backend->context; + struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; + + @autoreleasepool { + id device = ctx_dev->mtl_device; + + // wrap the source data into a Metal buffer + id buf_src = [device newBufferWithBytes:data + length:size + options:MTLResourceStorageModeShared]; + + size_t buf_dst_offset = 0; + id buf_dst = ggml_metal_get_buffer(tensor, &buf_dst_offset); + + if (buf_dst == nil) { + GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name); + } + + buf_dst_offset += offset; + + // queue the copy operation into the queue of the Metal context + // this will be queued at the end, after any currently ongoing GPU operations + id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; + id encoder = [cmd_buf blitCommandEncoder]; + + [encoder copyFromBuffer:buf_src + sourceOffset:0 + toBuffer:buf_dst + destinationOffset:buf_dst_offset + size:size]; + + [encoder endEncoding]; + [cmd_buf commit]; + + // do not wait here for completion + //[cmd_buf waitUntilCompleted]; + + // instead, remember a reference to the command buffer and wait for it later if needed + [ctx->cmd_bufs_ext addObject:cmd_buf]; + ctx->cmd_buf_last = cmd_buf; + + [cmd_buf retain]; + } +} + +static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + struct ggml_backend_metal_context * ctx = backend->context; + struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; + + @autoreleasepool { + id device = ctx_dev->mtl_device; + + id buf_dst = [device newBufferWithBytesNoCopy:data + length:size + options:MTLResourceStorageModeShared + deallocator:nil]; + + size_t buf_src_offset = 0; + id buf_src = ggml_metal_get_buffer(tensor, &buf_src_offset); + + if (buf_src == nil) { + GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name); + } + + buf_src_offset += offset; + + // queue the copy operation into the queue of the Metal context + // this will be queued at the end, after any currently ongoing GPU operations + id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; + id encoder = [cmd_buf blitCommandEncoder]; + + [encoder copyFromBuffer:buf_src + sourceOffset:buf_src_offset + toBuffer:buf_dst + destinationOffset:0 + size:size]; + + [encoder endEncoding]; + [cmd_buf commit]; + + // do not wait here for completion + //[cmd_buf waitUntilCompleted]; + + // instead, remember a reference to the command buffer and wait for it later if needed + [ctx->cmd_bufs_ext addObject:cmd_buf]; + ctx->cmd_buf_last = cmd_buf; + + [cmd_buf retain]; + } +} + +static bool ggml_backend_metal_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) { + return false; + + GGML_UNUSED(backend_src); + GGML_UNUSED(backend_dst); + GGML_UNUSED(src); + GGML_UNUSED(dst); +} + static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { return ggml_metal_graph_compute(backend, cgraph); } @@ -6214,7 +6562,10 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { const int n_nodes_per_cb = ctx->n_nodes_per_cb; - id cmd_buf = ctx->cmd_bufs[cb_idx].obj; + id cmd_buf = ctx->cmd_bufs[cb_idx].obj; + struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool; + + ggml_metal_mem_pool_reset(mem_pool); id encoder = [cmd_buf computeCommandEncoder]; @@ -6228,9 +6579,6 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { const bool should_capture = ctx->capture_next_compute; - struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool; - ggml_metal_mem_pool_reset(mem_pool); - for (int idx = node_start; idx < node_end;) { if (should_capture) { [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]]; @@ -6264,15 +6612,19 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { static struct ggml_backend_i ggml_backend_metal_i = { /* .get_name = */ ggml_backend_metal_name, /* .free = */ ggml_backend_metal_free, - /* .set_tensor_async = */ NULL, - /* .get_tensor_async = */ NULL, - /* .cpy_tensor_async = */ NULL, - /* .synchronize = */ NULL, + /* .set_tensor_async = */ ggml_backend_metal_set_tensor_async, + /* .get_tensor_async = */ ggml_backend_metal_get_tensor_async, + /* .cpy_tensor_async = */ ggml_backend_metal_cpy_tensor_async, // only needed for multi-GPU setups + /* .synchronize = */ ggml_backend_metal_synchronize, /* .graph_plan_create = */ NULL, /* .graph_plan_free = */ NULL, /* .graph_plan_update = */ NULL, /* .graph_plan_compute = */ NULL, /* .graph_compute = */ ggml_backend_metal_graph_compute, + + // the events API is needed only for multi-GPU setups, so likely no need to implement it for Metal + // in any case, these docs seem relevant if we ever decide to implement it: + // https://developer.apple.com/documentation/metal/mtlcommandbuffer#Synchronizing-Passes-with-Events /* .event_record = */ NULL, /* .event_wait = */ NULL, /* .optimize_graph = */ NULL, @@ -6376,7 +6728,7 @@ static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct g props->type = ggml_backend_metal_device_get_type(dev); ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total); props->caps = (struct ggml_backend_dev_caps) { - /* .async = */ false, + /* .async = */ true, /* .host_buffer = */ false, /* .buffer_from_host_ptr = */ true, /* .events = */ false, @@ -6407,17 +6759,19 @@ static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, con } static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml_backend_dev_t dev) { - return ggml_backend_metal_buffer_type(); + struct ggml_backend_metal_device_context * ctx_dev = dev->context; - GGML_UNUSED(dev); + return ctx_dev->use_shared_buffers ? ggml_backend_metal_buffer_type_shared() : ggml_backend_metal_buffer_type_private(); } -static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { +static ggml_backend_buffer_t ggml_backend_metal_device_buffer_mapped(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context)); ctx->all_data = ptr; ctx->all_size = size; - ctx->owned = false; + + ctx->is_shared = true; + ctx->n_buffers = 0; const size_t size_page = sysconf(_SC_PAGESIZE); @@ -6440,6 +6794,9 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back id device = ctx_dev->mtl_device; + ctx->device = device; + ctx->queue = ctx_dev->mtl_queue; + // the buffer fits into the max buffer size allowed by the device if (size_aligned <= device.maxBufferLength) { ctx->buffers[ctx->n_buffers].data = ptr; @@ -6497,7 +6854,7 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back return NULL; } - return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size); + return ggml_backend_buffer_init(ggml_backend_metal_buffer_type_mapped(), ggml_backend_metal_buffer_shared_i, ctx, size); } static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { @@ -6508,14 +6865,30 @@ static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { return - buft->iface.get_name == ggml_backend_metal_buffer_type_get_name || - buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name; + buft->iface.get_name == ggml_backend_metal_buffer_type_shared_get_name || + buft->iface.get_name == ggml_backend_metal_buffer_type_private_get_name || + buft->iface.get_name == ggml_backend_metal_buffer_type_mapped_get_name; GGML_UNUSED(dev); } +static int64_t get_op_batch_size(const struct ggml_tensor * op) { + switch (op->op) { + case GGML_OP_MUL_MAT: + return op->ne[1]; + case GGML_OP_MUL_MAT_ID: + return op->ne[2]; + default: + return ggml_nrows(op); + } +} + static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { - return false; + const int min_batch_size = 32; + + return (op->op == GGML_OP_MUL_MAT || + op->op == GGML_OP_MUL_MAT_ID) && + get_op_batch_size(op) >= min_batch_size; GGML_UNUSED(dev); GGML_UNUSED(op); @@ -6530,7 +6903,7 @@ static struct ggml_backend_device_i ggml_backend_metal_device_i = { /* .init_backend = */ ggml_backend_metal_device_init, /* .get_buffer_type = */ ggml_backend_metal_device_get_buffer_type, /* .get_host_buffer_type = */ NULL, - /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_from_ptr, + /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_mapped, /* .supports_op = */ ggml_backend_metal_device_supports_op, /* .supports_buft = */ ggml_backend_metal_device_supports_buft, /* .offload_op = */ ggml_backend_metal_device_offload_op, diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 77be3c5c9d..157d0cc6d0 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -5571,38 +5571,6 @@ kernel void kernel_flash_attn_ext_vec_reduce( #undef DV } -template -kernel void kernel_set( - constant ggml_metal_kargs_set & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i13 = tgpig[2]; - const int i12 = tgpig[1]; - const int i11 = tgpig[0]; - - const int64_t n = i13*args.ne12*args.ne11*args.ne10 + i12*args.ne11*args.ne10 + i11*args.ne10; - - const int64_t i3 = n / (args.ne12*args.ne11*args.ne10); - const int64_t i2 = (n - i3*args.ne12*args.ne11*args.ne10) / (args.ne11*args.ne10); - const int64_t i1 = (n - i3*args.ne12*args.ne11*args.ne10 - i2*args.ne11*args.ne10) / args.ne10; - - device T * dst_data = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + args.offs); - - for (int64_t i10 = tpitg.x; i10 < args.ne10; i10 += ntg.x) { - device const T * src = (device T *) (src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10); - dst_data[i10] = (T) src[0]; - } -} - -typedef decltype(kernel_set) kernel_set_t; - -template [[host_name("kernel_set_f32")]] kernel kernel_set_t kernel_set; -template [[host_name("kernel_set_i32")]] kernel kernel_set_t kernel_set; - template kernel void kernel_cpy( constant ggml_metal_kargs_cpy & args, From 9de447d94e1ae9d1a36e5a2e5bf47483352c0d9c Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Wed, 10 Sep 2025 17:31:40 +0200 Subject: [PATCH 43/46] ggml-cpu : fix padding in ggml_timestep_embedding (#15917) This commit fixes the zero padding for odd dimensions in ggml_compute_forward_timestep_embedding_f32. The motivation for this is that currently if an odd dimension is used, the padding check incorrectly uses the dimension value for indexing. For example, with dim=15: Elements 0-6 are set to cosine values Elements 7-13 are set to sine values Element 14 is left uninitialized (contains garbage) Element 15 is correctly set to zero This fix changes embed_data[dim] to embed_data[2 * half] so that element 14 (the first unused element) is properly set to zero as well as the last element. Resolves: https://github.com/ggml-org/ggml/issues/1324 --- ggml/src/ggml-cpu/ops.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 9adf910769..212e52ef6a 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8598,6 +8598,7 @@ static void ggml_compute_forward_timestep_embedding_f32( embed_data[j + half] = sinf(arg); } if (dim % 2 != 0 && ith == 0) { + embed_data[2 * half] = 0.f; embed_data[dim] = 0.f; } } From 6ab397e12ba8e9f776341cdae68f7ffb2f8d2cde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Wed, 10 Sep 2025 19:08:59 +0200 Subject: [PATCH 44/46] graph : support non-contiguous Q in build_attn_mha (#15908) * support non-contiguous Q in build_attn_mha * Update src/llama-graph.cpp ggml-ci Co-authored-by: Georgi Gerganov --------- Co-authored-by: Georgi Gerganov --- src/llama-graph.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 7f254b25cd..ddc772b179 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1273,7 +1273,7 @@ ggml_tensor * llm_graph_context::build_attn_mha( // split the batch into streams if needed const auto n_stream = k->ne[3]; - q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream); + q = ggml_view_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream, q->nb[1], q->nb[2], q->nb[3]/n_stream, 0); q = ggml_permute(ctx0, q, 0, 2, 1, 3); k = ggml_permute(ctx0, k, 0, 2, 1, 3); From 4f658855fa8f2e42b7ed9a5b298fa39a2e39b096 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jie=20Fu=20=28=E5=82=85=E6=9D=B0=29?= Date: Thu, 11 Sep 2025 02:51:51 +0800 Subject: [PATCH 45/46] llama : support T5 models with unequal number of encoder-decoder layers (#15909) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Extend the support of T5 models with different encoder-decoder layers Signed-off-by: Jie Fu * Update convert_hf_to_gguf.py Co-authored-by: Sigbjørn Skjæret * Update gguf-py/gguf/constants.py Co-authored-by: Sigbjørn Skjæret * Update gguf-py/gguf/gguf_writer.py Co-authored-by: Sigbjørn Skjæret * Update src/llama-arch.cpp Co-authored-by: Sigbjørn Skjæret * Update src/llama-arch.h Co-authored-by: Sigbjørn Skjæret * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret * Update src/llama-hparams.h Co-authored-by: Sigbjørn Skjæret * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret * Rename n_dec_layer --> dec_n_layer Signed-off-by: Jie Fu * Adapt to cases when dec_n_layer > n_layer Signed-off-by: Jie Fu --------- Signed-off-by: Jie Fu Co-authored-by: Sigbjørn Skjæret --- convert_hf_to_gguf.py | 2 ++ gguf-py/gguf/constants.py | 1 + gguf-py/gguf/gguf_writer.py | 3 +++ src/llama-arch.cpp | 1 + src/llama-arch.h | 1 + src/llama-hparams.h | 1 + src/llama-model.cpp | 26 ++++++++++++++++++++++---- 7 files changed, 31 insertions(+), 4 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 62a546ee22..bbc21813f8 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6701,6 +6701,8 @@ class T5Model(TextModel): self.gguf_writer.add_embedding_length(self.hparams["d_model"]) self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"]) self.gguf_writer.add_block_count(self.hparams["num_layers"]) + if (dec_n_layer := self.hparams.get("num_decoder_layers")) is not None: + self.gguf_writer.add_decoder_block_count(dec_n_layer) self.gguf_writer.add_head_count(self.hparams["num_heads"]) self.gguf_writer.add_key_length(self.hparams["d_kv"]) self.gguf_writer.add_value_length(self.hparams["d_kv"]) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index b8ac394580..1e88b6505b 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -109,6 +109,7 @@ class Keys: POOLING_TYPE = "{arch}.pooling_type" LOGIT_SCALE = "{arch}.logit_scale" DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id" + DECODER_BLOCK_COUNT = "{arch}.decoder_block_count" ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping" FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping" SWIN_NORM = "{arch}.swin_norm" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index a6cc8a931e..7ff12f7f57 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -676,6 +676,9 @@ class GGUFWriter: def add_decoder_start_token_id(self, id: int) -> None: self.add_uint32(Keys.LLM.DECODER_START_TOKEN_ID.format(arch=self.arch), id) + def add_decoder_block_count(self, value: int) -> None: + self.add_uint32(Keys.LLM.DECODER_BLOCK_COUNT.format(arch=self.arch), value) + def add_embedding_length_per_layer_input(self, value: int) -> None: self.add_uint32(Keys.LLM.EMBD_LENGTH_PER_LAYER_INP.format(arch=self.arch), value) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 77b2fecf18..81f9746818 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -137,6 +137,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_POOLING_TYPE, "%s.pooling_type" }, { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, + { LLM_KV_DECODER_BLOCK_COUNT, "%s.decoder_block_count" }, { LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" }, { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" }, { LLM_KV_SWIN_NORM, "%s.swin_norm" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index 21ab47bd7a..6ee3707dcf 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -141,6 +141,7 @@ enum llm_kv { LLM_KV_POOLING_TYPE, LLM_KV_LOGIT_SCALE, LLM_KV_DECODER_START_TOKEN_ID, + LLM_KV_DECODER_BLOCK_COUNT, LLM_KV_ATTN_LOGIT_SOFTCAPPING, LLM_KV_FINAL_LOGIT_SOFTCAPPING, LLM_KV_SWIN_NORM, diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 89f5c7ab65..4dca2ca41d 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -159,6 +159,7 @@ struct llama_hparams { // needed by encoder-decoder models (e.g. T5, FLAN-T5) // ref: https://github.com/ggerganov/llama.cpp/pull/8141 llama_token dec_start_token_id = LLAMA_TOKEN_NULL; + uint32_t dec_n_layer = 0; enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE; enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index b9e4634a70..818b209641 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1542,6 +1542,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.dec_start_token_id = dec_start_token_id; } + hparams.dec_n_layer = hparams.n_layer; + ml.get_key(LLM_KV_DECODER_BLOCK_COUNT, hparams.dec_n_layer, false); + switch (hparams.n_layer) { case 6: type = LLM_TYPE_60M; break; // t5-small case 8: type = LLM_TYPE_80M; break; // flan-t5-small @@ -4414,6 +4417,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } + // n_layer: number of encoder_layers + // dec_n_layer: number of decoder_layers + const int dec_n_layer = hparams.dec_n_layer; + if (dec_n_layer > n_layer) { + layers.resize(dec_n_layer); + } + + // load encoder layers for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -4429,6 +4440,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); layer.ffn_up_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + + // load decoder layers + for (int i = 0; i < dec_n_layer; ++i) { + auto & layer = layers[i]; layer.attn_norm = create_tensor(tn(LLM_TENSOR_DEC_ATTN_NORM, "weight", i), {n_embd}, 0); layer.attn_rel_b = create_tensor(tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); @@ -13509,7 +13525,9 @@ struct llm_build_t5_dec : public llm_graph_context { ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + const int64_t dec_n_layer = hparams.dec_n_layer; + + for (int il = 0; il < dec_n_layer; ++il) { ggml_tensor * inpSA = inpL; // norm @@ -13600,7 +13618,7 @@ struct llm_build_t5_dec : public llm_graph_context { //cb(cur, "kqv_out", il); } - if (il == n_layer - 1 && inp_out_ids) { + if (il == dec_n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids); } @@ -13621,8 +13639,8 @@ struct llm_build_t5_dec : public llm_graph_context { model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, - model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU, - model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ, + model.layers[il].ffn_gate ? LLM_FFN_GELU : LLM_FFN_RELU, + model.layers[il].ffn_gate ? LLM_FFN_PAR : LLM_FFN_SEQ, il); cb(cur, "ffn_out", il); } From 00681dfc16ba4cebb9c7fbd2cf2656e06a0692a4 Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Wed, 10 Sep 2025 22:04:03 +0200 Subject: [PATCH 46/46] CUDA: Add `fastdiv` to `k_bin_bcast*`, giving 1-3% E2E performance (#15872) * Add fastdiv and fastmodulo to k_bin_bcast kernel * Address review comments * `prod_` instead of `prod` suffix * Add test case for `k_bin_bcast_unravel` in CUDA backend --- ggml/src/ggml-cuda/binbcast.cu | 182 +++++++++++++++++++-------------- tests/test-backend-ops.cpp | 3 + 2 files changed, 111 insertions(+), 74 deletions(-) diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index 1c76566344..725e1a81a1 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -23,28 +23,44 @@ static __device__ __forceinline__ float op_div(const float a, const float b) { return a / b; } +template +static __global__ void k_bin_bcast(const src0_t * src0, + const src1_t * src1, + dst_t * dst, + const int ne0, + const int ne1, + const int ne2, + const uint3 ne3, + const uint3 ne10, + const uint3 ne11, + const uint3 ne12, + const uint3 ne13, + /*int s0, */ const int s1, + const int s2, + const int s3, + /*int s00,*/ const int s01, + const int s02, + const int s03, + /*int s10,*/ const int s11, + const int s12, + const int s13, + src1_ptrs... src1s) { + const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x; + const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y); + const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3); + const uint32_t i3 = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z); - -template -static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, - const int ne0, const int ne1, const int ne2, const int ne3, - const int ne10, const int ne11, const int ne12, const int ne13, - /*int s0, */ const int s1, const int s2, const int s3, - /*int s00,*/ const int s01, const int s02, const int s03, - /*int s10,*/ const int s11, const int s12, const int s13, - src1_ptrs... src1s) { - const int i0s = blockDim.x*blockIdx.x + threadIdx.x; - const int i1 = (blockDim.y*blockIdx.y + threadIdx.y); - const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3; - const int i3 = (blockDim.z*blockIdx.z + threadIdx.z) % ne3; - - if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { + if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3.z) { return; } - const int i11 = i1 % ne11; - const int i12 = i2 % ne12; - const int i13 = i3 % ne13; + const uint32_t i11 = fastmodulo(i1, ne11); + const uint32_t i12 = fastmodulo(i2, ne12); + const uint32_t i13 = fastmodulo(i3, ne13); const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; @@ -53,8 +69,8 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr; dst_t * dst_row = dst + i_dst; - for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) { - const int i10 = i0 % ne10; + for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) { + const uint32_t i10 = fastmodulo(i0, ne10); float result = src0_row ? (float) src0_row[i0] : 0.0f; if constexpr (sizeof...(src1_ptrs) > 0) { @@ -67,28 +83,48 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst } } -template -static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst, - const int ne0, const int ne1, const int ne2,const int ne3, - const int ne10, const int ne11, const int ne12, const int ne13, - /*int s0, */ const int s1, const int s2, const int s3, - /*int s00,*/ const int s01, const int s02, const int s03, - /*int s10,*/ const int s11, const int s12, const int s13, - src1_ptrs ... src1s) { +template +static __global__ void k_bin_bcast_unravel(const src0_t * src0, + const src1_t * src1, + dst_t * dst, + const uint3 ne0, + const uint3 ne1, + const uint3 ne2, + const uint32_t ne3, + const uint3 prod_012, + const uint3 prod_01, + const uint3 ne10, + const uint3 ne11, + const uint3 ne12, + const uint3 ne13, + /*int s0, */ const int s1, + const int s2, + const int s3, + /*int s00,*/ const int s01, + const int s02, + const int s03, + /*int s10,*/ const int s11, + const int s12, + const int s13, + src1_ptrs... src1s) { const int i = blockDim.x*blockIdx.x + threadIdx.x; - const int i3 = i/(ne2*ne1*ne0); - const int i2 = (i/(ne1*ne0)) % ne2; - const int i1 = (i/ne0) % ne1; - const int i0 = i % ne0; + const uint32_t i3 = fastdiv(i, prod_012); + const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01); + const uint32_t i1 = fastdiv(i - i3 * prod_012.z - i2 * prod_01.z, ne0); + const uint32_t i0 = i - i3 * prod_012.z - i2 * prod_01.z - i1 * ne0.z; - if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { + if (i0 >= ne0.z || i1 >= ne1.z || i2 >= ne2.z || i3 >= ne3) { return; } - const int i11 = i1 % ne11; - const int i12 = i2 % ne12; - const int i13 = i3 % ne13; + const int i11 = fastmodulo(i1, ne11); + const int i12 = fastmodulo(i2, ne12); + const int i13 = fastmodulo(i3, ne13); const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; @@ -97,7 +133,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr; dst_t * dst_row = dst + i_dst; - const int i10 = i0 % ne10; + const int i10 = fastmodulo(i0, ne10); float result = src0_row ? (float) src0_row[i0] : 0.0f; if constexpr (sizeof...(src1_ptrs) > 0) { @@ -170,11 +206,6 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02); //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03); - int64_t ne10 = cne1[0]; - int64_t ne11 = cne1[1]; - int64_t ne12 = cne1[2]; - int64_t ne13 = cne1[3]; - size_t nb0 = cnb[0]; size_t nb1 = cnb[1]; size_t nb2 = cnb[2]; @@ -233,48 +264,51 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * block_dims.y = std::min(ne1, block_size / block_dims.x); block_dims.z = std::min(std::min(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U); - dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x, - (ne1 + block_dims.y - 1) / block_dims.y, + dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x, (ne1 + block_dims.y - 1) / block_dims.y, (ne2 * ne3 + block_dims.z - 1) / block_dims.z); + const uint3 ne10 = init_fastdiv_values((uint32_t) cne1[0]); + const uint3 ne11 = init_fastdiv_values((uint32_t) cne1[1]); + const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]); + const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]); + if (block_nums.z > 65535) { - int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size; + int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size; + const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2)); + const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1)); + const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0); + const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1); + const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2); + if constexpr (sizeof...(I) > 0) { - k_bin_bcast_unravel - <<>>(src0_dd, src1_dd, dst_dd, - ne0, ne1, ne2, ne3, - ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12,s13, - (const src1_t *) dst->src[I + 1]->data...); + k_bin_bcast_unravel<<>>( + src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, + ne12, ne13, + /* s0, */ s1, s2, s3, + /* s00,*/ s01, s02, s03, + /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); } else { k_bin_bcast_unravel - <<>>(src0_dd, src1_dd, dst_dd, - ne0, ne1, ne2, ne3, - ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12,s13); + <<>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, + ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13, + /* s0, */ s1, s2, s3, + /* s00,*/ s01, s02, s03, + /* s10,*/ s11, s12, s13); } } else { + const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3); if constexpr (sizeof...(I) > 0) { - k_bin_bcast - <<>>(src0_dd, src1_dd, dst_dd, - ne0, ne1, ne2, ne3, - ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12,s13, - (const src1_t *) dst->src[I + 1]->data...); + k_bin_bcast<<>>( + src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, + /* s0, */ s1, s2, s3, + /* s00,*/ s01, s02, s03, + /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); } else { - k_bin_bcast - <<>>(src0_dd, src1_dd, dst_dd, - ne0, ne1, ne2, ne3, - ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12,s13); + k_bin_bcast<<>>( + src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, + /* s0, */ s1, s2, s3, + /* s00,*/ s01, s02, s03, + /* s10,*/ s11, s12, s13); } } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 4a882ab072..b54a1a4e82 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6050,6 +6050,9 @@ static std::vector> make_test_cases_eval() { add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 2, 2}); add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2}); + // test case for k_bin_bcast_unravel in CUDA backend + add_test_bin_bcast(type, {1, 1, 65536, 1}, {256, 1, 1, 1}); + // stable diffusion add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 1, 1, 1}); add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 16, 16, 1});