diff --git a/.github/ISSUE_TEMPLATE/010-bug-compilation.yml b/.github/ISSUE_TEMPLATE/010-bug-compilation.yml
index feb0d51205..c106f47a25 100644
--- a/.github/ISSUE_TEMPLATE/010-bug-compilation.yml
+++ b/.github/ISSUE_TEMPLATE/010-bug-compilation.yml
@@ -8,7 +8,8 @@ body:
value: >
Thanks for taking the time to fill out this bug report!
This issue template is intended for bug reports where the compilation of llama.cpp fails.
- Before opening an issue, please confirm that the compilation still fails with `-DGGML_CCACHE=OFF`.
+ Before opening an issue, please confirm that the compilation still fails
+ after recreating the CMake build directory and with `-DGGML_CCACHE=OFF`.
If the compilation succeeds with ccache disabled you should be able to permanently fix the issue
by clearing `~/.cache/ccache` (on Linux).
- type: textarea
diff --git a/.github/ISSUE_TEMPLATE/011-bug-results.yml b/.github/ISSUE_TEMPLATE/011-bug-results.yml
index b815e70a8d..31202dfa83 100644
--- a/.github/ISSUE_TEMPLATE/011-bug-results.yml
+++ b/.github/ISSUE_TEMPLATE/011-bug-results.yml
@@ -98,7 +98,18 @@ body:
label: Relevant log output
description: >
Please copy and paste any relevant log output, including the command that you entered and any generated text.
- This will be automatically formatted into code, so no need for backticks.
- render: shell
+ For very long logs (thousands of lines), please upload them as files instead.
+ On Linux you can redirect console output into a file by appending ` > llama.log 2>&1` to your command.
+ value: |
+ <details>
+ <summary>Logs</summary>
+
+ ```console
+
+ ```
+
+ </details>
+
+
validations:
required: true
diff --git a/.github/ISSUE_TEMPLATE/019-bug-misc.yml b/.github/ISSUE_TEMPLATE/019-bug-misc.yml
index e1bd08ddd2..8e867e7f60 100644
--- a/.github/ISSUE_TEMPLATE/019-bug-misc.yml
+++ b/.github/ISSUE_TEMPLATE/019-bug-misc.yml
@@ -85,8 +85,19 @@ body:
label: Relevant log output
description: >
If applicable, please copy and paste any relevant log output, including any generated text.
- This will be automatically formatted into code, so no need for backticks.
If you are encountering problems specifically with the `llama_params_fit` module, always upload `--verbose` logs as well.
- render: shell
+ For very long logs (thousands of lines), please upload them as files instead.
+ On Linux you can redirect console output into a file by appending ` > llama.log 2>&1` to your command.
+ value: |
+ <details>
+ <summary>Logs</summary>
+
+ ```console
+
+ ```
+
+ </details>
+
+
validations:
required: false
diff --git a/common/arg.cpp b/common/arg.cpp
index 1302065498..189470182a 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2087,7 +2087,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"override tensor buffer type", [](common_params & params, const std::string & value) {
parse_tensor_buffer_overrides(value, params.tensor_buft_overrides);
}
- ));
+ ).set_env("LLAMA_ARG_OVERRIDE_TENSOR"));
add_opt(common_arg(
{"-otd", "--override-tensor-draft"}, "=,...",
"override tensor buffer type for draft model", [](common_params & params, const std::string & value) {
@@ -2137,11 +2137,18 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_N_CPU_MOE_DRAFT"));
+ GGML_ASSERT(params.n_gpu_layers < 0); // string_format would need to be extended for a default >= 0
add_opt(common_arg(
{"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
- string_format("max. number of layers to store in VRAM (default: %d)", params.n_gpu_layers),
- [](common_params & params, int value) {
- params.n_gpu_layers = value;
+ string_format("max. number of layers to store in VRAM, either an exact number, 'auto', or 'all' (default: %s)", params.n_gpu_layers == -1 ? "auto" : "all"),
+ [](common_params & params, const std::string & value) {
+ if (value == "auto") {
+ params.n_gpu_layers = -1;
+ } else if (value == "all") {
+ params.n_gpu_layers = -2;
+ } else {
+ params.n_gpu_layers = std::stoi(value);
+ }
if (!llama_supports_gpu_offload()) {
fprintf(stderr, "warning: no usable GPU found, --gpu-layers option will be ignored\n");
fprintf(stderr, "warning: one possible reason is that llama.cpp was compiled without GPU support\n");
@@ -3175,11 +3182,19 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.devices = parse_device_list(value);
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
+ GGML_ASSERT(params.speculative.n_gpu_layers < 0); // string_format would need to be extended for a default >= 0
add_opt(common_arg(
{"-ngld", "--gpu-layers-draft", "--n-gpu-layers-draft"}, "N",
- "number of layers to store in VRAM for the draft model",
- [](common_params & params, int value) {
- params.speculative.n_gpu_layers = value;
+ string_format("max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: %s)",
+ params.speculative.n_gpu_layers == -1 ? "auto" : "all"),
+ [](common_params & params, const std::string & value) {
+ if (value == "auto") {
+ params.speculative.n_gpu_layers = -1;
+ } else if (value == "all") {
+ params.speculative.n_gpu_layers = -2;
+ } else {
+ params.speculative.n_gpu_layers = std::stoi(value);
+ }
if (!llama_supports_gpu_offload()) {
fprintf(stderr, "warning: no usable GPU found, --gpu-layers-draft option will be ignored\n");
fprintf(stderr, "warning: one possible reason is that llama.cpp was compiled without GPU support\n");
@@ -3518,15 +3533,15 @@ void common_params_add_preset_options(std::vector<common_arg> & args) {
[](common_params &, const std::string &) { /* unused */ }
).set_env(COMMON_ARG_PRESET_LOAD_ON_STARTUP).set_preset_only());
+ args.push_back(common_arg(
+ {"stop-timeout"}, "SECONDS",
+ "in server router mode, force-kill model instance after this many seconds of graceful shutdown",
+ [](common_params &, int) { /* unused */ }
+ ).set_env(COMMON_ARG_PRESET_STOP_TIMEOUT).set_preset_only());
+
// args.push_back(common_arg(
// {"pin"},
// "in server router mode, do not unload this model if models_max is exceeded",
// [](common_params &) { /* unused */ }
// ).set_preset_only());
-
- // args.push_back(common_arg(
- // {"unload-idle-seconds"}, "SECONDS",
- // "in server router mode, unload models idle for more than this many seconds",
- // [](common_params &, int) { /* unused */ }
- // ).set_preset_only());
}
diff --git a/common/arg.h b/common/arg.h
index f5111c658f..a1b6a14e67 100644
--- a/common/arg.h
+++ b/common/arg.h
@@ -10,6 +10,7 @@
// pseudo-env variable to identify preset-only arguments
#define COMMON_ARG_PRESET_LOAD_ON_STARTUP "__PRESET_LOAD_ON_STARTUP"
+#define COMMON_ARG_PRESET_STOP_TIMEOUT "__PRESET_STOP_TIMEOUT"
//
// CLI argument parsing
diff --git a/common/common.cpp b/common/common.cpp
index acf2ec841d..8d62893370 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1341,10 +1341,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
mparams.devices = params.devices.data();
}
- if (params.n_gpu_layers != -1) {
- mparams.n_gpu_layers = params.n_gpu_layers;
- }
-
+ mparams.n_gpu_layers = params.n_gpu_layers;
mparams.main_gpu = params.main_gpu;
mparams.split_mode = params.split_mode;
mparams.tensor_split = params.tensor_split;
diff --git a/common/common.h b/common/common.h
index b6f8902189..55749dd8c7 100644
--- a/common/common.h
+++ b/common/common.h
@@ -329,7 +329,7 @@ struct common_params {
// offload params
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
- int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
+ int32_t n_gpu_layers = -1; // number of layers to store in VRAM, -1 is auto, <= -2 is all
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
bool fit_params = true; // whether to fit unset model/context parameters to free device memory
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 16c5acf346..69abb7367d 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -7362,6 +7362,90 @@ class MiniMaxM2Model(TextModel):
return super().modify_tensors(data_torch, name, bid)
+@ModelBase.register("MiMoV2FlashForCausalLM")
+class MimoV2Model(TextModel):
+ model_arch = gguf.MODEL_ARCH.MIMO2
+
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+
+ assert self.hparams["swa_head_dim"] == self.hparams["head_dim"]
+ assert self.hparams["swa_num_attention_heads"] == self.hparams["num_attention_heads"]
+ assert self.hparams["swa_v_head_dim"] == self.hparams["v_head_dim"]
+ assert self.hparams["topk_method"] == "noaux_tc"
+
+ n_head_kv = self.hparams["num_key_value_heads"]
+ n_head_kv_swa = self.hparams["swa_num_key_value_heads"]
+ n_head_kv_arr = [n_head_kv_swa if use_swa == 1 else n_head_kv for use_swa in self.hparams["hybrid_layer_pattern"]]
+ self.gguf_writer.add_head_count_kv(n_head_kv_arr)
+
+ self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
+ self.gguf_writer.add_sliding_window_pattern(self.hparams["hybrid_layer_pattern"])
+ self.gguf_writer.add_rope_freq_base_swa(self.hparams["swa_rope_theta"])
+ self.gguf_writer.add_value_length(self.hparams["v_head_dim"])
+ self.gguf_writer.add_expert_count(self.hparams["n_routed_experts"])
+ self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
+
+ rope_dim = int(self.hparams["head_dim"] * self.hparams["partial_rotary_factor"])
+ self.gguf_writer.add_rope_dimension_count(rope_dim)
+
+ self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon", 1e-5))
+
+ _experts: list[dict[str, Tensor]] | None = None
+
+ def modify_tensors(self, data_torch, name, bid):
+ if name.endswith("e_score_correction_bias"):
+ name = name.replace("e_score_correction_bias", "e_score_correction.bias")
+
+ if "attention_sink" in name and not name.endswith(".weight"):
+ name += ".weight"
+
+ # TODO: mimo v2 does not indicate the number of next-token-prediction layers, therefore we cannot handle it the same way as GLM4_MOE
+ if "model.mtp." in name:
+ return []
+
+ # process the experts separately
+ if name.find("mlp.experts") != -1:
+ n_experts = self.hparams["n_routed_experts"]
+ assert bid is not None
+
+ if self._experts is None:
+ self._experts = [{} for _ in range(self.block_count)]
+
+ self._experts[bid][name] = data_torch
+
+ if len(self._experts[bid]) >= n_experts * 3:
+ tensors: list[tuple[str, Tensor]] = []
+
+ # merge the experts into a single 3d tensor
+ for w_name in ["gate_proj", "up_proj", "down_proj"]:
+ datas: list[Tensor] = []
+
+ for xid in range(n_experts):
+ ename_to_retrieve = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+ datas.append(self._experts[bid][ename_to_retrieve])
+ del self._experts[bid][ename_to_retrieve]
+
+ data_torch = torch.stack(datas, dim=0)
+ merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+ new_name = self.map_tensor_name(merged_name)
+ tensors.append((new_name, data_torch))
+
+ return tensors
+ else:
+ return []
+ return [(self.map_tensor_name(name), data_torch)]
+
+ def prepare_tensors(self):
+ super().prepare_tensors()
+
+ if self._experts is not None:
+ # flatten `list[dict[str, Tensor]]` into `list[str]`
+ experts = [k for d in self._experts for k in d.keys()]
+ if len(experts) > 0:
+ raise ValueError(f"Unprocessed experts: {experts}")
+
+
@ModelBase.register("PanguEmbeddedForCausalLM")
class PanguEmbeddedModel(TextModel):
model_arch = gguf.MODEL_ARCH.PANGU_EMBED
@@ -8695,6 +8779,11 @@ class NemotronHModel(GraniteHybridModel):
raise ValueError(f"Unprocessed experts: {experts}")
+@ModelBase.register("LlamaBidirectionalModel")
+class LlamaEmbedNemotronModel(LlamaModel):
+ model_arch = gguf.MODEL_ARCH.LLAMA_EMBED
+
+
@ModelBase.register("BailingMoeForCausalLM")
class BailingMoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.BAILINGMOE
diff --git a/docs/backend/OPENCL.md b/docs/backend/OPENCL.md
index e52baffdff..ce6c7b5605 100644
--- a/docs/backend/OPENCL.md
+++ b/docs/backend/OPENCL.md
@@ -17,7 +17,7 @@ OpenCL (Open Computing Language) is an open, royalty-free standard for cross-pla
### Llama.cpp + OpenCL
-The llama.cpp OpenCL backend is designed to enable llama.cpp on **Qualcomm Adreno GPU** firstly via OpenCL. Thanks to the portabilty of OpenCL, the OpenCL backend can also run on certain Intel GPUs although the performance is not optimal.
+The llama.cpp OpenCL backend is designed to enable llama.cpp on **Qualcomm Adreno GPU** firstly via OpenCL. Thanks to the portability of OpenCL, the OpenCL backend can also run on certain Intel GPUs, such as those without [SYCL](/docs/backend/SYCL.md) support, although the performance is not optimal.
## OS
diff --git a/docs/backend/SYCL.md b/docs/backend/SYCL.md
index f44458ed3b..bcb3ce6743 100644
--- a/docs/backend/SYCL.md
+++ b/docs/backend/SYCL.md
@@ -829,7 +829,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
No. We can't support Ollama issue directly, because we aren't familiar with Ollama.
- Sugguest reproducing on llama.cpp and report similar issue to llama.cpp. We will surpport it.
+ Suggest reproducing on llama.cpp and reporting a similar issue to llama.cpp. We will support it.
It's same for other projects including llama.cpp SYCL backend.
diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt
index 18d117f7cc..cb46c32100 100644
--- a/ggml/CMakeLists.txt
+++ b/ggml/CMakeLists.txt
@@ -430,10 +430,22 @@ if (MSVC)
configure_msvc_target(ggml-cpu-x64)
configure_msvc_target(ggml-cpu-sse42)
configure_msvc_target(ggml-cpu-sandybridge)
+ # __FMA__ and __F16C__ are not defined by MSVC; however, they are implied by AVX2/AVX512
+ # skipping ggml-cpu-ivybridge
+ # skipping ggml-cpu-piledriver
configure_msvc_target(ggml-cpu-haswell)
configure_msvc_target(ggml-cpu-skylakex)
+ configure_msvc_target(ggml-cpu-cannonlake)
+ configure_msvc_target(ggml-cpu-cascadelake)
configure_msvc_target(ggml-cpu-icelake)
+ # MSVC 2022 doesn't support BF16 intrinsics without `/arch:AVX10.1` ?!
+ # https://learn.microsoft.com/en-us/cpp/intrinsics/x64-amd64-intrinsics-list?view=msvc-170
+ # https://learn.microsoft.com/en-us/cpp/build/reference/arch-x64?view=msvc-170
+ # skipping ggml-cpu-cooperlake
+ # skipping ggml-cpu-zen4
configure_msvc_target(ggml-cpu-alderlake)
+ # MSVC doesn't support AMX
+ # skipping ggml-cpu-sapphirerapids
if (GGML_BUILD_EXAMPLES)
configure_msvc_target(common-ggml)
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 262d78a4cf..25f25c4236 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -357,15 +357,29 @@ if (GGML_CPU_ALL_VARIANTS)
endif()
if (GGML_SYSTEM_ARCH STREQUAL "x86")
ggml_add_cpu_backend_variant(x64)
- ggml_add_cpu_backend_variant(sse42 SSE42)
- ggml_add_cpu_backend_variant(sandybridge SSE42 AVX)
- ggml_add_cpu_backend_variant(haswell SSE42 AVX F16C AVX2 BMI2 FMA)
- ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C AVX2 BMI2 FMA AVX512)
- ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
- ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C AVX2 BMI2 FMA AVX_VNNI)
+ ggml_add_cpu_backend_variant(sse42 SSE42)
+ ggml_add_cpu_backend_variant(sandybridge SSE42 AVX)
+ if (NOT MSVC)
+ # __FMA__ and __F16C__ are not defined by MSVC; however, they are implied by AVX2/AVX512
+ ggml_add_cpu_backend_variant(ivybridge SSE42 AVX F16C)
+ ggml_add_cpu_backend_variant(piledriver SSE42 AVX F16C FMA)
+ endif()
+ ggml_add_cpu_backend_variant(haswell SSE42 AVX F16C FMA AVX2 BMI2)
+ ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C FMA AVX2 BMI2 AVX512)
+ ggml_add_cpu_backend_variant(cannonlake SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI)
+ ggml_add_cpu_backend_variant(cascadelake SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VNNI)
+ ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI AVX512_VNNI)
+ if (NOT MSVC)
+ # MSVC 2022 doesn't support BF16 intrinsics without `/arch:AVX10.1` ?!
+ # https://learn.microsoft.com/en-us/cpp/intrinsics/x64-amd64-intrinsics-list?view=msvc-170
+ # https://learn.microsoft.com/en-us/cpp/build/reference/arch-x64?view=msvc-170
+ ggml_add_cpu_backend_variant(cooperlake SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VNNI AVX512_BF16)
+ ggml_add_cpu_backend_variant(zen4 SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16)
+ endif()
+ ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C FMA AVX2 BMI2 AVX_VNNI)
if (NOT MSVC)
# MSVC doesn't support AMX
- ggml_add_cpu_backend_variant(sapphirerapids SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
+ ggml_add_cpu_backend_variant(sapphirerapids SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
endif()
elseif(GGML_SYSTEM_ARCH STREQUAL "ARM")
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp
index 835b53f659..2180a06fd0 100644
--- a/ggml/src/ggml-cann/aclnn_ops.cpp
+++ b/ggml/src/ggml-cann/aclnn_ops.cpp
@@ -2338,19 +2338,19 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
// Step1.2: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor.
// TODO: acl_yarn_ramp_tensor use rope cache.
bool yarn_ramp_tensor_updated = false;
- ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
acl_tensor_ptr acl_yarn_ramp_tensor;
if (ext_factor != 0 && (theta_scale_updated || ctx.rope_cache.theta_scale_length != theta_scale_length ||
ctx.rope_cache.freq_scale != freq_scale)) {
yarn_ramp_tensor_updated = true;
-
+ if (ctx.rope_cache.yarn_ramp_cache != nullptr) {
+ ACL_CHECK(aclrtFree(ctx.rope_cache.yarn_ramp_cache));
+ }
+ ACL_CHECK(aclrtMalloc(&ctx.rope_cache.yarn_ramp_cache, theta_scale_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
// -rope_yarn_ramp
// const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
// return MIN(1, MAX(0, y)) - 1;
- yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
- void * yarn_ramp_buffer = yarn_ramp_allocator.get();
acl_yarn_ramp_tensor =
- ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
+ ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
float zero_value = 0, one_value = 1;
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT);
@@ -2380,8 +2380,10 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
acl_scalar_ptr freq_scale_1_sc = ggml_cann_create_scalar(&freq_scale_1, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get());
+ } else {
+ acl_yarn_ramp_tensor =
+ ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
}
-
// Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale.
if (ext_factor != 0) {
if (theta_scale_updated || yarn_ramp_tensor_updated) {
@@ -2988,32 +2990,156 @@ void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
GGML_CANN_CALL_ACLNN_OP(ctx, ArgMax, acl_src.get(), 3, false, acl_dst.get());
}
-void ggml_cann_conv_transpose_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
+void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){
ggml_tensor * src0 = dst->src[0];
ggml_tensor * src1 = dst->src[1];
// stride
- int64_t s0 = ((const int32_t *) (dst->op_params))[0];
+ int64_t s0 = ((const int32_t*)(dst->op_params))[0];
- acl_tensor_ptr acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL);
+ acl_tensor_ptr acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL);
acl_tensor_ptr acl_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3, ACL_FORMAT_NCL);
- acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL);
+ acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL);
+
+ // get base information of input and kernel
+ int64_t input_len = *(src1->ne);
+ int64_t dst_len = *(dst->ne);
+ int64_t kernel_size = *(src0->ne);
+
+ // set the max kernel size for each conv
+ int64_t max_kernel_size = 255;
+
+ // compute the partition of kernel
+ int64_t part_num = 1;
+ part_num = (kernel_size + max_kernel_size - 1) / max_kernel_size;
int64_t strideVal[1];
- strideVal[0] = s0;
- acl_int_array_ptr stride = ggml_cann_create_int_array(strideVal, 1);
- int64_t paddingVal[] = { 0 };
- acl_int_array_ptr padding = ggml_cann_create_int_array(paddingVal, 1);
- int64_t dilationVal[] = { 1 };
- acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1);
- int8_t cubeMathType = 0;
+ strideVal[0] = s0;
+ acl_int_array_ptr stride = ggml_cann_create_int_array(strideVal, 1);
+ int64_t paddingVal[] = {0};
+ acl_int_array_ptr padding = ggml_cann_create_int_array(paddingVal, 1);
+ int64_t dilationVal[] = {1};
+ acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1);
+ bool transposed = true;
+ int64_t groups = 1;
+ int8_t cubeMathType = 0;
#ifdef ASCEND_310P
cubeMathType = 1;
#endif
- GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input.get(), acl_weight.get(), nullptr, stride.get(), padding.get(),
- dilation.get(), true, padding.get(), 1, acl_dst.get(), cubeMathType);
+ auto weight_type = ggml_cann_type_mapping(src0->type);
+ auto dst_type = ggml_cann_type_mapping(dst->type);
+
+ // slice the kernel to make each conv available
+ int64_t slice_dim = -1;
+ int64_t slice_start = 0;
+ int64_t slice_end = max_kernel_size;
+ int64_t slice_step = 1;
+ int64_t interval = max_kernel_size;
+
+ int64_t left_pad_len = dilationVal[0] * (max_kernel_size - 1) + 1 - 2 * paddingVal[0];
+ int64_t right_pad_len = 0;
+
+ acl_scalar_ptr alpha = nullptr;
+ float alphaValue = 1.0;
+ alpha = ggml_cann_create_scalar(&alphaValue, aclDataType::ACL_FLOAT);
+
+ // set zero to destination
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get());
+
+ for (int k = 0; k < part_num; k++) {
+
+ // create part kernel tensor and slice from big kernel
+ slice_start = max_kernel_size * k;
+ if (k == part_num - 1) {
+ slice_end = kernel_size;
+ interval = kernel_size - max_kernel_size * k;
+ } else {
+ slice_end = max_kernel_size * (k + 1);
+ }
+
+ int64_t part_ne[4];
+ for (int i = 0; i < 4; i++) {
+ part_ne[i] = *(src0->ne + i);
+ }
+ part_ne[0] = interval;
+
+ size_t part_nb[4];
+ part_nb[0] = ggml_element_size(src0);
+ for (int i = 1; i < 4; i++) {
+ part_nb[i] = part_nb[i - 1] * part_ne[i - 1];
+ }
+
+ ggml_cann_pool_alloc part_kernel_allocator;
+ part_kernel_allocator.alloc(ctx.pool(), part_nb[3]);
+ void* part_kernel_buf = part_kernel_allocator.get();
+
+ acl_tensor_ptr part_kernel = ggml_cann_create_tensor(part_kernel_buf, weight_type,
+ ggml_element_size(src0), part_ne, part_nb, 3, ACL_FORMAT_NCL);
+
+ GGML_CANN_CALL_ACLNN_OP(ctx, Slice, acl_weight.get(), slice_dim, slice_start, slice_end, slice_step, part_kernel.get());
+
+ // create the part conv result tensor
+ int64_t part_dst_ne[4];
+ for (int i = 0; i < 4; i++) {
+ part_dst_ne[i] = *(dst->ne + i);
+ }
+ part_dst_ne[0] = (input_len - 1) * strideVal[0] - 2 * paddingVal[0] + dilationVal[0] * (part_ne[0] - 1) + 1;
+
+ size_t part_dst_nb[4];
+ part_dst_nb[0] = ggml_element_size(dst);
+ for (int i = 1; i < 4; i++) {
+ part_dst_nb[i] = part_dst_nb[i - 1] * part_dst_ne[i - 1];
+ }
+ ggml_cann_pool_alloc part_dst_allocator;
+ part_dst_allocator.alloc(ctx.pool(), part_dst_nb[3]);
+ void* part_dst_buf = part_dst_allocator.get();
+
+ acl_tensor_ptr acl_part_dst = ggml_cann_create_tensor(part_dst_buf, dst_type, ggml_element_size(dst),
+ part_dst_ne, part_dst_nb, 3, ACL_FORMAT_NCL);
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_part_dst.get());
+
+ // compute part conv transpose 1d
+ GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input.get(), part_kernel.get(), nullptr, stride.get(),
+ padding.get(), dilation.get(), transposed, padding.get(), groups, acl_part_dst.get(), cubeMathType);
+
+ // compute the position of part result in final result
+ int64_t global_start = slice_start;
+ int64_t global_end = std::min((input_len - 1) * strideVal[0] + slice_end, dst_len);
+
+ left_pad_len = global_start;
+ right_pad_len = dst_len - global_end;
+
+ std::vector<int64_t> padDataVal = { left_pad_len, right_pad_len };
+ acl_int_array_ptr padData = ggml_cann_create_int_array(padDataVal.data(), 2);
+
+ acl_scalar_ptr pad_value = nullptr;
+ float pad_valueVal = 0.0;
+ pad_value = ggml_cann_create_scalar(&pad_valueVal, aclDataType::ACL_FLOAT);
+
+ int64_t conv_result_ne[4];
+ for (int i = 0; i < 4; i++) {
+ conv_result_ne[i] = *(dst->ne + i);
+ }
+
+ size_t conv_result_nb[4];
+ conv_result_nb[0] = ggml_element_size(dst);
+ for (int i = 1; i < 4; i++) {
+ conv_result_nb[i] = conv_result_nb[i - 1] * conv_result_ne[i - 1];
+ }
+
+ ggml_cann_pool_alloc conv_result_allocator;
+ conv_result_allocator.alloc(ctx.pool(), conv_result_nb[3]);
+ void* conv_result_buf = conv_result_allocator.get();
+
+ acl_tensor_ptr conv_result = ggml_cann_create_tensor(conv_result_buf, dst_type, ggml_element_size(dst),
+ conv_result_ne, conv_result_nb, 3, ACL_FORMAT_NCL);
+
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, conv_result.get());
+ GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_part_dst.get(), padData.get(), pad_value.get(), conv_result.get());
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst.get(), conv_result.get(), alpha.get());
+ }
}
void ggml_cann_elu(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
@@ -3576,3 +3702,106 @@ void ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
break;
}
}
+
+void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
+ ggml_tensor * src0 = dst->src[0]; // conv_x
+ ggml_tensor * src1 = dst->src[1]; // conv1d.weight
+
+ // This op is currently defined only for F32 in ggml_cpu
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ // Shapes follow ggml_compute_forward_ssm_conv_f32
+ const int64_t nc = src1->ne[0]; // d_conv
+ const int64_t ncs = src0->ne[0]; // d_conv - 1 + n_t
+ const int64_t nr = src0->ne[1]; // d_inner
+ const int64_t n_s = src0->ne[2]; // n_seqs
+
+ const int64_t n_t = dst->ne[1]; // tokens per sequence
+
+ GGML_ASSERT(dst->ne[0] == nr); // dst: {d_inner, n_t, n_s}
+ GGML_ASSERT(src1->ne[1] == nr); // weight: {d_conv, d_inner}
+ GGML_ASSERT(ncs == nc - 1 + n_t); // conv_x: {d_conv - 1 + n_t, d_inner, n_s}
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+ GGML_ASSERT(src1->nb[0] == sizeof(float));
+
+ // --- Build CANN tensors ---
+
+ // 1) Input: conv_x as NCL
+ //
+ // src0->ne = { ncs, nr, n_s, 1 } // {L_in, C, N}
+ // Passing ACL_FORMAT_NCL here means:
+ // reversed dims -> [N, C, L_in] = [n_s, nr, ncs]
+ acl_tensor_ptr acl_x = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3, ACL_FORMAT_NCL);
+
+ // 2) Weights: depthwise conv kernel, view src1 as {K, 1, C}
+ //
+ // src1 original: ne = { nc, nr, 1, 1 } // [K, C, 1, 1]
+ // we want a view: ne_w = { nc, 1, nr } // [K, 1, C]
+ // so that reversed dims -> [C, 1, K] which matches
+ // [out_channels, in_channels/groups, kernel_size]
+ int64_t w_ne[GGML_MAX_DIMS] = { nc, 1, nr, 1 }; // [K, 1 input ch. per group, C groups]
+ // Layout: src1 data is [K, C] with
+ // offset(k, c) = k*nb0 + c*nb1
+ // We want offset_w(k, 0, c) = k*nb0 + c*nb1,
+ // so we can reuse nb0 and nb1, and set nb2 = nb1.
+ size_t w_nb[GGML_MAX_DIMS] = { src1->nb[0], src1->nb[1], src1->nb[1], src1->nb[3] }; // same as src1
+
+ acl_tensor_ptr acl_w = ggml_cann_create_tensor(
+ src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), w_ne, w_nb, 3, ACL_FORMAT_NCL);
+
+ // 3) Output: dst is { d_inner, n_t, n_s } (CLN)
+ //
+ // We need an NCL view of the same buffer:
+ // desired NCL logical shape: { L_out = n_t, C = nr, N = n_s }
+ //
+ // Original CLN layout:
+ // dst->ne = { nr, n_t, n_s }
+ // dst->nb[0] = sizeof(float)
+ // dst->nb[1] = nr * sizeof(float)
+ // dst->nb[2] = nr * n_t * sizeof(float)
+ //
+ // We want offset_new(L, C, N) = offset_orig(C, L, N).
+ // Choose:
+ // nb_y[0] = nr * sizeof(float); // step in L
+ // nb_y[1] = sizeof(float); // step in C
+ // nb_y[2] = nr * n_t * sizeof(float); // step in N
+ int64_t y_ne[GGML_MAX_DIMS] = { n_t, nr, n_s, 1 }; // [L_out, C, N]
+ size_t y_nb[GGML_MAX_DIMS] = { dst->ne[0] * sizeof(float), sizeof(float), dst->ne[0] * dst->ne[1] * sizeof(float), dst->nb[3] }; // [nr, 1, nr * n_t]
+
+ acl_tensor_ptr acl_y = ggml_cann_create_tensor(
+ dst->data, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), y_ne, y_nb, 3, ACL_FORMAT_NCL);
+
+ // --- Conv1d parameters: depthwise, stride 1, no padding ("valid") ---
+ int64_t strideVal[1] = { 1 };
+ int64_t paddingVal[1] = { 0 };
+ int64_t dilationVal[1] = { 1 };
+
+ acl_int_array_ptr stride = ggml_cann_create_int_array(strideVal, 1);
+ acl_int_array_ptr padding = ggml_cann_create_int_array(paddingVal, 1);
+ acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1);
+
+ const bool transposed = false;
+ const int64_t groups = nr; // depthwise: one group per inner dim
+ int8_t cubeMathType = 0;
+
+#ifdef ASCEND_310P
+ cubeMathType = 1;
+#endif
+
+ GGML_CANN_CALL_ACLNN_OP(ctx,
+ Convolution,
+ acl_x.get(), // input: N, C, L_in = ncs
+ acl_w.get(), // weight: [C, 1, K] with groups=nr
+ nullptr, // bias
+ stride.get(),
+ padding.get(),
+ dilation.get(),
+ transposed,
+ padding.get(), // output padding (unused for non-transposed)
+ groups,
+ acl_y.get(),
+ cubeMathType);
+}
+
diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h
index 1ebbc769c7..a6ea016c54 100644
--- a/ggml/src/ggml-cann/aclnn_ops.h
+++ b/ggml/src/ggml-cann/aclnn_ops.h
@@ -47,6 +47,7 @@
#include
#include
#include
+#include
#include
#include
@@ -1032,6 +1033,8 @@ void ggml_cann_op_unary(std::function
+void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h
--- a/ggml/src/ggml-cann/common.h
+++ b/ggml/src/ggml-cann/common.h
+ bool has_matching_properties(ggml_tensor * node) {
+ if (node->data != this->node_address && node->op != GGML_OP_VIEW) {
+ return false;
+ }
+
+ if (node->op != this->node_op) {
+ return false;
+ }
+
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
+ if (node->ne[i] != this->ne[i]) {
+ return false;
+ }
+ if (node->nb[i] != this->nb[i]) {
+ return false;
+ }
+ }
+
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
+ if (node->src[i]) {
+ if (node->src[i]->data != this->src_address[i] && node->op != GGML_OP_VIEW) {
+ return false;
+ }
+
+ for (int d = 0; d < GGML_MAX_DIMS; d++) {
+ if (node->src[i]->ne[d] != this->src_ne[i][d]) {
+ return false;
+ }
+ if (node->src[i]->nb[d] != this->src_nb[i][d]) {
+ return false;
+ }
+ }
+ } else {
+ if (this->src_address[i] != nullptr) {
+ return false;
+ }
+ }
+ }
+
+ if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) {
+ return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
+ }
+ return true;
+ }
};
struct ggml_cann_graph {
@@ -241,6 +295,79 @@ struct ggml_cann_graph {
aclmdlRI graph = nullptr;
std::vector<ggml_graph_node_properties> ggml_graph_properties;
+
+ /**
+ * @brief Create a new CANN graph from a ggml computation graph.
+ *
+ * This function creates a new ggml_cann_graph object and fills its node properties
+ * (operation type, dimensions, strides, input sources, and operation parameters)
+ * based on the current ggml computation graph.
+ *
+ * Each node in the ggml graph is mapped to a property entry in the new CANN graph:
+ * - node address
+ * - operation type
+ * - shape (ne) and strides (nb)
+ * - source tensor addresses
+ * - operation parameters
+ *
+ * @param cgraph The current ggml computation graph.
+ * @return Pointer to the newly created ggml_cann_graph object.
+ */
+ static ggml_cann_graph * create_from_cgraph(ggml_cgraph * cgraph) {
+ ggml_cann_graph * new_graph = new ggml_cann_graph();
+ new_graph->ggml_graph_properties.resize(cgraph->n_nodes);
+
+ for (int node_idx = 0; node_idx < cgraph->n_nodes; ++node_idx) {
+ ggml_tensor * node = cgraph->nodes[node_idx];
+ auto & prop = new_graph->ggml_graph_properties[node_idx];
+
+ prop.node_address = node->data;
+ prop.node_op = node->op;
+
+ std::copy_n(node->ne, GGML_MAX_DIMS, prop.ne);
+ std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);
+
+ for (int src = 0; src < GGML_MAX_SRC; ++src) {
+ if (node->src[src]) {
+ prop.src_address[src] = node->src[src]->data;
+ std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]);
+ std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]);
+ } else {
+ prop.src_address[src] = nullptr;
+ std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0);
+ std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0);
+ }
+ }
+
+ memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);
+ }
+
+ return new_graph;
+ }
+
+ /**
+ * @brief Check whether this CANN graph matches the given ggml computation graph.
+ *
+ * This function compares the number of nodes and each node's properties
+ * (operation type, dimensions, strides, inputs, and operation parameters)
+ * to determine whether this CANN graph matches the given ggml graph.
+ *
+ * @param cgraph The current ggml computation graph.
+ * @return true if this CANN graph matches the ggml graph; false otherwise.
+ */
+ bool matches_cgraph(ggml_cgraph * cgraph) {
+ if (this->ggml_graph_properties.size() != static_cast<size_t>(cgraph->n_nodes)) {
+ return false;
+ }
+
+ for (int i = 0; i < cgraph->n_nodes; ++i) {
+ if (!this->ggml_graph_properties[i].has_matching_properties(cgraph->nodes[i])) {
+ return false;
+ }
+ }
+
+ return true;
+ }
};
/**
@@ -272,15 +399,6 @@ struct ggml_cann_graph_lru_cache {
cache_list.push_front(new_node);
}
- /**
- * @brief Move an existing graph to the front of the cache.
- * @param node Pointer to the ggml_cann_graph to move.
- */
- void move_to_front(ggml_cann_graph * node) {
- cache_list.remove(node);
- cache_list.push_front(node);
- }
-
/**
* @brief Clear all graphs from the cache (also frees memory).
*/
@@ -295,6 +413,28 @@ struct ggml_cann_graph_lru_cache {
* @brief Destructor that clears the cache and frees all cached graphs.
*/
~ggml_cann_graph_lru_cache() { clear(); }
+
+ /**
+ * @brief Find a cached CANN graph that matches the given ggml graph and move it to front.
+ *
+ * This function iterates through the cached CANN graphs stored in the LRU cache and
+ * compares them against the given ggml computation graph. If a matching graph is found,
+ * it is promoted to the front of the LRU cache and returned. Otherwise, the function
+ * returns nullptr.
+ *
+ * @param cgraph The current ggml computation graph.
+ * @return true if found; false otherwise.
+ */
+ bool find_and_move_to_front(ggml_cgraph * cgraph) {
+ for (auto & graph_ptr : this->cache_list) {
+ if (graph_ptr->matches_cgraph(cgraph)) {
+ cache_list.remove(graph_ptr);
+ cache_list.push_front(graph_ptr);
+ return true;
+ }
+ }
+ return false;
+ }
};
#endif // USE_ACL_GRAPH
@@ -318,6 +458,9 @@ struct ggml_cann_rope_cache {
if (position_select_index_host) {
free(position_select_index_host);
}
+ if (yarn_ramp_cache) {
+ ACL_CHECK(aclrtFree(yarn_ramp_cache));
+ }
}
bool equal(int64_t theta_scale_length,
@@ -370,6 +513,7 @@ struct ggml_cann_rope_cache {
float * theta_scale_exp_host = nullptr;
int * position_select_index_host = nullptr;
void * position_select_index = nullptr;
+ void * yarn_ramp_cache = nullptr;
// sin/cos cache, used only to accelerate first layer on each device
void * sin_cache = nullptr;
void * cos_cache = nullptr;
diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
index da624c587c..ef23ec78da 100644
--- a/ggml/src/ggml-cann/ggml-cann.cpp
+++ b/ggml/src/ggml-cann/ggml-cann.cpp
@@ -1888,6 +1888,8 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
break;
case GGML_OP_OUT_PROD:
ggml_cann_out_prod(ctx, dst);
+ break;
+ case GGML_OP_SSM_CONV:
+ ggml_cann_ssm_conv(ctx, dst);
break;
default:
return false;
@@ -2075,162 +2077,6 @@ static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
}
-#ifdef USE_ACL_GRAPH
-/**
- * @brief Add a new CANN graph to the LRU cache by populating node properties from the ggml graph.
- *
- * This function creates a new ggml_cann_graph object and fills its node properties
- * (operation type, dimensions, strides, input sources, and operation parameters)
- * based on the current ggml computation graph.
- *
- * Each node in the ggml graph is mapped to a property entry in the new CANN graph:
- * - node address
- * - operation type
- * - shape (ne) and strides (nb)
- * - source tensor addresses
- * - operation parameters
- *
- * After initialization, the new graph is pushed into the LRU cache owned by the
- * CANN backend context. The cache takes ownership of the graph and manages its
- * lifetime (including deletion upon eviction).
- *
- * @param cann_ctx The CANN backend context containing the graph cache.
- * @param cgraph The current ggml computation graph.
- */
-static void add_lru_matched_graph_node_properties(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
- // Create a new ggml_cann_graph object on the heap (its lifetime is managed by the cache).
- ggml_cann_graph * new_graph = new ggml_cann_graph();
- new_graph->ggml_graph_properties.resize(cgraph->n_nodes);
-
- for (int node_idx = 0; node_idx < cgraph->n_nodes; ++node_idx) {
- ggml_tensor * node = cgraph->nodes[node_idx];
- auto & prop = new_graph->ggml_graph_properties[node_idx];
-
- prop.node_address = node->data;
- prop.node_op = node->op;
-
- std::copy_n(node->ne, GGML_MAX_DIMS, prop.ne);
- std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);
-
- for (int src = 0; src < GGML_MAX_SRC; ++src) {
- if (node->src[src]) {
- prop.src_address[src] = node->src[src]->data;
- std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]);
- std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]);
- } else {
- prop.src_address[src] = nullptr;
- std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0);
- std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0);
- }
- }
-
- memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);
- }
-
- // Insert into the LRU cache (cache takes ownership and will delete it when evicted).
- cann_ctx->graph_lru_cache.push(new_graph);
-}
-
-/**
- * @brief Check if a ggml tensor node matches a previously captured CANN graph node.
- *
- * This function compares all relevant fields (address, op type, shape, source inputs, op params)
- * to determine whether the current node matches a previously recorded version.
- *
- * @param node The current ggml tensor node.
- * @param graph_node_properties The stored properties of a CANN graph node.
- * @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
- */
-static bool ggml_graph_node_has_matching_properties(ggml_tensor * node,
- ggml_graph_node_properties * graph_node_properties) {
- if (node->data != graph_node_properties->node_address && node->op != GGML_OP_VIEW) {
- return false;
- }
-
- if (node->op != graph_node_properties->node_op) {
- return false;
- }
-
- for (int i = 0; i < GGML_MAX_DIMS; i++) {
- if (node->ne[i] != graph_node_properties->ne[i]) {
- return false;
- }
- if (node->nb[i] != graph_node_properties->nb[i]) {
- return false;
- }
- }
-
- for (int i = 0; i < GGML_MAX_SRC; i++) {
- if (node->src[i]) {
- if (node->src[i]->data != graph_node_properties->src_address[i] && node->op != GGML_OP_VIEW) {
- return false;
- }
-
- for (int d = 0; d < GGML_MAX_DIMS; d++) {
- if (node->src[i]->ne[d] != graph_node_properties->src_ne[i][d]) {
- return false;
- }
- if (node->src[i]->nb[d] != graph_node_properties->src_nb[i][d]) {
- return false;
- }
- }
- } else {
- if (graph_node_properties->src_address[i] != nullptr) {
- return false;
- }
- }
- }
-
- if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) {
- return memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
- }
- return true;
-}
-
-/**
- * @brief Check whether there is a cached CANN graph that matches the current ggml graph.
- *
- * This function iterates through the cached CANN graphs stored in the LRU cache and
- * compares them against the given ggml computation graph. A match requires that the
- * number of nodes is the same and that each node’s properties (operation type,
- * dimensions, strides, inputs, and operation parameters) are identical.
- *
- * If a matching graph is found, it is promoted to the front of the LRU cache and the
- * function returns true. Otherwise, the function returns false, indicating that a new
- * CANN graph needs to be captured.
- *
- * @param cann_ctx The CANN backend context containing the graph cache.
- * @param cgraph The current ggml computation graph.
- * @return true if a matching cached graph exists; false otherwise.
- */
-static bool is_matched_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
- ggml_cann_graph_lru_cache & lru_cache = cann_ctx->graph_lru_cache;
- for (auto & graph_ptr : lru_cache.cache_list) {
- // Skip graphs with a different number of nodes.
- if (graph_ptr->ggml_graph_properties.size() != static_cast<size_t>(cgraph->n_nodes)) {
- continue;
- }
-
- // Check if all nodes match.
- bool all_match = true;
- for (int i = 0; i < cgraph->n_nodes; ++i) {
- if (!ggml_graph_node_has_matching_properties(cgraph->nodes[i], &graph_ptr->ggml_graph_properties[i])) {
- all_match = false;
- break;
- }
- }
-
- if (all_match) {
- // update cache_list && renturn graph_ptr
- lru_cache.move_to_front(graph_ptr);
- return true;
- }
- }
-
- return false;
-}
-#endif // USE_ACL_GRAPH
-
/**
* @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API.
*
@@ -2239,23 +2085,23 @@ static bool is_matched_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph *
*
* Otherwise, it falls back to op-by-op execution using the CANN compute kernel dispatcher.
*
- * @param cann_ctx The CANN backend context.
- * @param cgraph The ggml computation graph.
- * @param use_cann_graph Whether to use CANN graph execution.
- * @param cann_graph_update_required Whether graph capture is needed due to graph changes.
+ * @param cann_ctx The CANN backend context.
+ * @param cgraph The ggml computation graph.
+ * @param use_cann_graph Whether to use CANN graph execution.
+ * @param cann_graph_capture_required Whether graph capture is needed due to graph changes.
*/
static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx,
ggml_cgraph * cgraph,
- bool & use_cann_graph,
- bool & cann_graph_update_required) {
+ bool use_cann_graph,
+ bool cann_graph_capture_required) {
#ifdef USE_ACL_GRAPH
- if (use_cann_graph && cann_graph_update_required) { // Begin CANN graph capture
+ if (use_cann_graph && cann_graph_capture_required) { // Begin CANN graph capture
ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL));
}
#endif // USE_ACL_GRAPH
// Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
// With the use of CANN graphs, the execution will be performed by the graph launch.
- if (!use_cann_graph || cann_graph_update_required) {
+ if (!use_cann_graph || cann_graph_capture_required) {
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
@@ -2274,9 +2120,10 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
#ifdef USE_ACL_GRAPH
if (use_cann_graph) {
+ GGML_ASSERT(!cann_ctx->graph_lru_cache.cache_list.empty());
ggml_cann_graph * matched_graph = cann_ctx->graph_lru_cache.cache_list.front();
- if (cann_graph_update_required) { // End CANN graph capture
+ if (cann_graph_capture_required) { // End CANN graph capture
ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph));
}
@@ -2306,7 +2153,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
// calculate rope cache for fist layer in current device.
cann_ctx->rope_cache.cached = false;
- bool cann_graph_update_required = false;
+ bool graph_capture_required = false;
#ifdef USE_ACL_GRAPH
bool use_cann_graph = true;
@@ -2331,16 +2178,17 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
if (use_cann_graph) {
// If no matching graph is found, the graph needs to be recaptured.
- cann_graph_update_required = !is_matched_graph(cann_ctx, cgraph);
- if (cann_graph_update_required) {
+ graph_capture_required = !cann_ctx->graph_lru_cache.find_and_move_to_front(cgraph);
+ if (graph_capture_required) {
// If no matching graph is found, add a new ACL graph.
- add_lru_matched_graph_node_properties(cann_ctx, cgraph);
+ ggml_cann_graph * new_graph = ggml_cann_graph::create_from_cgraph(cgraph);
+ cann_ctx->graph_lru_cache.push(new_graph);
}
}
#else
bool use_cann_graph = false;
#endif // USE_ACL_GRAPH
- evaluate_and_capture_cann_graph(cann_ctx, cgraph, use_cann_graph, cann_graph_update_required);
+ evaluate_and_capture_cann_graph(cann_ctx, cgraph, use_cann_graph, graph_capture_required);
return GGML_STATUS_SUCCESS;
}
@@ -2578,8 +2426,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
}
}
case GGML_OP_CONV_TRANSPOSE_1D:
- // TODO: ((weightL - 1) * dilationW - padLeft)=1336 should not be larger than 255.
- return (op->src[0]->ne[0] - 1) <= 255;
+ return true;
case GGML_OP_SCALE:
float bias;
memcpy(&bias, (const float *) (op->op_params) + 1, sizeof(float));
@@ -2626,6 +2473,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
}
return true;
}
+ case GGML_OP_SSM_CONV:
+ return true;
default:
return false;
}
diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h
index 7597377cc2..0e8dd0ae05 100644
--- a/ggml/src/ggml-cpu/ggml-cpu-impl.h
+++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h
@@ -328,7 +328,7 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b)
#if defined(_MSC_VER) || defined(__MINGW32__)
#include
-#elif defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__)
+#elif defined(__SSE__) || defined(__SSE3__) || defined(__SSSE3__) || defined(__AVX__) || defined(__F16C__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX512BF16__)
#include
#endif
diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h
index 101a9c086b..a7a8272205 100644
--- a/ggml/src/ggml-cpu/simd-mappings.h
+++ b/ggml/src/ggml-cpu/simd-mappings.h
@@ -14,10 +14,6 @@
#include
#endif
-#if defined(__F16C__)
-#include
-#endif
-
#if defined(__riscv_v_intrinsic)
#include
#endif
diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt
index 67af1d8ccc..f491217545 100644
--- a/ggml/src/ggml-cuda/CMakeLists.txt
+++ b/ggml/src/ggml-cuda/CMakeLists.txt
@@ -15,6 +15,7 @@ if (CUDAToolkit_FOUND)
# 80 == Ampere, asynchronous data loading, faster tensor core instructions
# 86 == RTX 3000, needs CUDA v11.1
# 89 == RTX 4000, needs CUDA v11.8
+ # 120 == Blackwell, needs CUDA v12.8, FP4 tensor cores
#
# XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run
# XX-real == compile CUDA code as device code for this specific architecture
@@ -36,10 +37,36 @@ if (CUDAToolkit_FOUND)
endif()
endif()
endif()
- message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
enable_language(CUDA)
+ # Replace any 12x-real architectures with 12x{a}-real. FP4 PTX instructions are not available with plain 12x
+ if (GGML_NATIVE)
+ set(PROCESSED_ARCHITECTURES "")
+ if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES AND CMAKE_CUDA_ARCHITECTURES_NATIVE)
+ set(ARCH_LIST ${CMAKE_CUDA_ARCHITECTURES_NATIVE})
+ else()
+ set(ARCH_LIST ${CMAKE_CUDA_ARCHITECTURES})
+ endif()
+ foreach(ARCH ${ARCH_LIST})
+ if (ARCH MATCHES "^12[0-9](-real|-virtual)?$")
+ string(REGEX REPLACE "^(12[0-9]).*$" "\\1" BASE_ARCH ${ARCH})
+ message(STATUS "Replacing ${ARCH} with ${BASE_ARCH}a-real")
+ list(APPEND PROCESSED_ARCHITECTURES "${BASE_ARCH}a-real")
+ else()
+ list(APPEND PROCESSED_ARCHITECTURES ${ARCH})
+ endif()
+ endforeach()
+ set(CMAKE_CUDA_ARCHITECTURES ${PROCESSED_ARCHITECTURES})
+ else()
+ foreach(ARCH ${CMAKE_CUDA_ARCHITECTURES})
+ if(ARCH MATCHES "^12[0-9](-real|-virtual)?$")
+ message(FATAL_ERROR "Compute capability ${ARCH} used, use ${ARCH}a or ${ARCH}f for Blackwell specific optimizations")
+ endif()
+ endforeach()
+ endif()
+ message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
+
file(GLOB GGML_HEADERS_CUDA "*.cuh")
list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 9fcb2f9fd2..62e618850b 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -50,6 +50,10 @@
#define GGML_CUDA_CC_TURING 750
#define GGML_CUDA_CC_AMPERE 800
#define GGML_CUDA_CC_ADA_LOVELACE 890
+// While Blackwell spans CC 1000, 1100 & 1200, we are integrating tensor core instructions available to the 1200 family, see
+// https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms
+#define GGML_CUDA_CC_BLACKWELL 1200
+#define GGML_CUDA_CC_RUBIN 1300
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
#define GGML_CUDA_CC_IS_NVIDIA(cc) (cc < GGML_CUDA_CC_OFFSET_MTHREADS)
@@ -246,6 +250,10 @@ static const char * cu_get_error_str(CUresult err) {
#define AMPERE_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL && __CUDA_ARCH__ < GGML_CUDA_CC_RUBIN
+# define BLACKWELL_MMA_AVAILABLE
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL && __CUDA_ARCH__ < GGML_CUDA_CC_RUBIN
+
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#define CP_ASYNC_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
@@ -316,6 +324,11 @@ static bool cp_async_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
}
+static bool blackwell_mma_available(const int cc) {
+ return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_BLACKWELL &&
+ ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_RUBIN;
+}
+
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
return 64;
@@ -701,6 +714,28 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
#endif // CUDART_VERSION >= 12050
}
+__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {
+ const uint8_t sign_bit = (x < 0.0f) << 3;
+ float ax = fabsf(x) * e;
+
+ // Positive LUT
+ static constexpr float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };
+
+ int best_i = 0;
+ float best_err = fabsf(ax - pos_lut[0]);
+
+#pragma unroll
+ for (int i = 1; i < 8; ++i) {
+ const float err = fabsf(ax - pos_lut[i]);
+ if (err < best_err) {
+ best_err = err;
+ best_i = i;
+ }
+ }
+
+ return static_cast<uint8_t>(best_i | sign_bit);
+}
+
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
// Precompute mp (m' in the paper) and L such that division
// can be computed using a multiply (high 32b of 64b result)
diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu
index d2f2def8bd..e82171f9c2 100644
--- a/ggml/src/ggml-cuda/cumsum.cu
+++ b/ggml/src/ggml-cuda/cumsum.cu
@@ -5,7 +5,7 @@
#include "ggml.h"
#ifdef GGML_CUDA_USE_CUB
-# include
+# include
#endif // GGML_CUDA_USE_CUB
template <typename T, int BLOCK_SIZE>
@@ -16,12 +16,14 @@ static __global__ void cumsum_cub_kernel(
const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t s1, const int64_t s2, const int64_t s3) {
#ifdef GGML_CUDA_USE_CUB
- using BlockScan = cub::BlockScan<T, BLOCK_SIZE>;
+ using BlockScanT = cub::BlockScan<T, BLOCK_SIZE>;
- __shared__ typename BlockScan::TempStorage temp_storage;
- __shared__ T block_carry; // carry from previous tile
+ __shared__ typename BlockScanT::TempStorage temp_storage;
+ __shared__ T block_carry;
const int tid = threadIdx.x;
+ constexpr int UNROLL_FACTOR = 4;
+ constexpr int TILE_SIZE = BLOCK_SIZE * UNROLL_FACTOR;
const int64_t i1 = blockIdx.x;
const int64_t i2 = blockIdx.y;
@@ -39,29 +41,38 @@ static __global__ void cumsum_cub_kernel(
}
__syncthreads();
- for (int64_t start = 0; start < ne00; start += BLOCK_SIZE) {
- int64_t idx = start + tid;
- T x = (idx < ne00) ? src_row[idx] : T(0);
+ for (int64_t start = 0; start < ne00; start += TILE_SIZE) {
+ T items[UNROLL_FACTOR];
+ T thread_sum = T(0);
- T inclusive;
- T block_total;
- BlockScan(temp_storage).InclusiveSum(x, inclusive, block_total);
-
- __syncthreads();
-
- T final_val = inclusive + block_carry;
-
- // store result
- if (idx < ne00) {
- dst_row[idx] = final_val;
+#pragma unroll
+ for (int i = 0; i < UNROLL_FACTOR; i++) {
+ int64_t idx = start + tid * UNROLL_FACTOR + i;
+ T val = (idx < ne00) ? src_row[idx] : T(0);
+ thread_sum += val;
+ items[i] = thread_sum;
}
+ // Block-wide scan on thread sums
+ T thread_prefix;
+ T block_total;
+ BlockScanT(temp_storage).InclusiveSum(thread_sum, thread_prefix, block_total);
__syncthreads();
+ // Add offset to each item and store
+ T thread_offset = thread_prefix - thread_sum + block_carry;
+ #pragma unroll
+ for (int i = 0; i < UNROLL_FACTOR; i++) {
+ int64_t idx = start + tid * UNROLL_FACTOR + i;
+ if (idx < ne00) {
+ dst_row[idx] = items[i] + thread_offset;
+ }
+ }
+
+ // Update carry for next tile
if (tid == 0) {
block_carry += block_total;
}
-
__syncthreads();
}
#else
@@ -69,7 +80,7 @@ static __global__ void cumsum_cub_kernel(
#endif // GGML_CUDA_USE_CUB
}
-// Fallback kernel implementation (original)
+// Fallback kernel implementation
template <typename T>
static __global__ void cumsum_kernel(
const T * src, T * dst,
@@ -86,10 +97,10 @@ static __global__ void cumsum_kernel(
const int warps_per_block = blockDim.x / warp_size;
extern __shared__ float smem[];
- float * s_vals = smem;
- float * s_warp_sums = smem + blockDim.x;
- float * s_carry = smem + blockDim.x + warps_per_block;
- float * s_chunk_total = s_carry + 1;
+ float * s_vals = smem;
+ float * s_warp_sums = smem + blockDim.x;
+ float * s_carry = smem + blockDim.x + warps_per_block;
+ float * s_chunk_total = s_carry + 1;
// Initialize carry
if (tid == 0) {
@@ -107,21 +118,39 @@ static __global__ void cumsum_kernel(
const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3;
- for (int64_t start = 0; start < ne00; start += blockDim.x) {
- int64_t idx = start + tid;
- float val = (idx < ne00) ? ggml_cuda_cast<float>(src_row[idx]) : 0.0f;
+ // register blocking: process 4 elements per thread to hide latency
+ // and reduce synchronization overhead
+ constexpr int num_unroll = 4;
+ T temp[num_unroll];
- // 1. Warp inclusive scan
+ for (int64_t i = 0; i < ne00; i += num_unroll * blockDim.x) {
+ int64_t idx = i + tid * num_unroll;
+
+ // thread local sequential scan
+ temp[0] = (idx < ne00 ? src_row[idx] : T(0));
+#pragma unroll
+ for (int64_t j = 1; j < num_unroll; j++) {
+ temp[j] = temp[j - 1];
+ if (idx + j < ne00) {
+ temp[j] += src_row[idx + j];
+ } else {
+ temp[j] += 0;
+ }
+ }
+
+ // last element is the sum of all values assigned to this thread
+ float val = (idx < ne00) ? ggml_cuda_cast<float>(temp[num_unroll - 1]) : 0.0f;
+
+ // Warp inclusive scan
val = warp_prefix_inclusive_sum(val);
s_vals[tid] = val;
- // Store warp total
if (lane == warp_size - 1) {
s_warp_sums[warp] = val;
}
__syncthreads();
- // 2. Exclusive scan of warp sums (warp 0 only)
+ // Exclusive scan of warp sums (warp 0 only)
if (warp == 0) {
float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
float inc = warp_prefix_inclusive_sum(w);
@@ -134,12 +163,17 @@ static __global__ void cumsum_kernel(
}
__syncthreads();
+ // write back results
float carry = *s_carry;
- float final_val = s_vals[tid] + s_warp_sums[warp] + carry;
- if (idx < ne00) {
- dst_row[idx] = ggml_cuda_cast<T>(final_val);
+ // calculate sum offset for this thread
+ float final_val_offset = s_vals[tid] + s_warp_sums[warp] + carry - temp[num_unroll - 1];
+
+#pragma unroll
+ for (int32_t j = 0; j < num_unroll; j++) {
+ if (idx + j < ne00) {
+ dst_row[idx + j] = temp[j] + ggml_cuda_cast<T>(final_val_offset);
+ }
}
- __syncthreads();
// Update carry for next chunk
if (tid == 0) {
@@ -177,7 +211,7 @@ static void cumsum_cuda(
const int warps_per_block = block_size / warp_size;
const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float);
- if (use_cub) {
+ if (use_cub && ne00 >= 1024) {
cumsum_cub_kernel<<>>(
src, dst,
ne00, ne01, ne02, ne03,
diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh
index 3268dadfe8..df9eed7117 100644
--- a/ggml/src/ggml-cuda/mma.cuh
+++ b/ggml/src/ggml-cuda/mma.cuh
@@ -900,6 +900,27 @@ namespace ggml_cuda_mma {
#endif // AMPERE_MMA_AVAILABLE
}
+ static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
+ const tile<16, 8, int> & A,
+ const tile<8, 8, int> & B,
+ uint32_t a_scale,
+ uint32_t b_scale) {
+#ifdef BLACKWELL_MMA_AVAILABLE
+ const int * Axi = (const int *) A.x;
+ const int * Bxi = (const int *) B.x;
+ float * Dxi = (float *) D.x;
+
+ asm volatile(
+ "mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
+ "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
+ "%10, {0, 0}, %11, {0, 0};"
+ : "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
+#else
+ GGML_UNUSED_VARS(D, A, B, a_scale, b_scale);
+#endif // BLACKWELL_MMA_AVAILABLE
+ }
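
For reference (not part of the patch): the block-scaled MMA multiplies each 32-value MXFP4 block by its E8M0 scale, decoded as 2^(e - 127). A scalar sketch of the assumed per-block semantics, ignoring the per-thread register layout:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

// Scalar model of a block-scaled dot product (assumed semantics): every
// 32-value block carries one E8M0 scale, and its contribution is
// scale_a * scale_b * sum(a_i * b_i).
static float e8m0_to_fp32(uint8_t e) {
    return std::ldexp(1.0f, (int) e - 127);
}

static float block_scaled_dot(const float * a, uint8_t ea,
                              const float * b, uint8_t eb, int n) {
    float acc = 0.0f;
    for (int i = 0; i < n; ++i) {
        acc += a[i] * b[i]; // E2M1 values, already decoded to float
    }
    return e8m0_to_fp32(ea) * e8m0_to_fp32(eb) * acc;
}

int main() {
    float a[32], b[32];
    for (int i = 0; i < 32; ++i) { a[i] = 1.0f; b[i] = 0.5f; }
    // ea = 128 -> scale 2, eb = 127 -> scale 1; expected 2 * 1 * 16 = 32
    printf("%.1f\n", block_scaled_dot(a, 128, b, 127, 32));
    return 0;
}
```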
+
static __device__ __forceinline__ void mma(
tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
#ifdef TURING_MMA_AVAILABLE
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index f7a2cbca90..6156dcdae7 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -1,3 +1,4 @@
+#include "common.cuh"
#include "mmq.cuh"
#include "quantize.cuh"
#include "mmid.cuh"
@@ -114,6 +115,9 @@ void ggml_cuda_mul_mat_q(
const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
|| GGML_CUDA_CC_IS_CDNA(cc);
+ // TODO: tighter pool buffer size vs q8 path
+ const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4;
+
if (!ids) {
const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
@@ -123,12 +127,24 @@ void ggml_cuda_mul_mat_q(
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[3] / ts_src1;
- quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type,
- ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
+ if (use_native_mxfp4) {
+ static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1));
+ quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
+ ne11, ne12, ne13, stream);
+
+ } else {
+ quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
+ ne11, ne12, ne13, stream);
+ }
CUDA_CHECK(cudaGetLastError());
}
- const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
+ // Stride depends on quantization format
+ const int64_t s12 = use_native_mxfp4 ?
+ ne11 * ne10_padded * sizeof(block_fp4_mmq) /
+ (8 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 256 values (8 blocks of 32)
+ :
+ ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
const int64_t s13 = ne12*s12;
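
Worked example (not part of the patch): with the sizes implied by the structs in this change, sizeof(block_q8_1) = 36 covers 32 values while sizeof(block_fp4_mmq) = 144 covers 256, so the per-channel stride in int units halves on the native MXFP4 path:

```cpp
#include <cstdint>
#include <cstdio>

// Worked example of the s12 stride computation, using assumed sizes:
// QK8_1 = QK_MXFP4 = 32, sizeof(block_q8_1) = 36, sizeof(block_fp4_mmq) = 144.
int main() {
    constexpr int QK8_1    = 32;
    constexpr int QK_MXFP4 = 32;
    constexpr int sizeof_block_q8_1    = 36;   // 32 int8 quants + half2 scale
    constexpr int sizeof_block_fp4_mmq = 144;  // 4 uint32 scales + 128 packed nibbles

    const int64_t ne11 = 8, ne10_padded = 4096;

    const int64_t s12_q8  = ne11 * ne10_padded * sizeof_block_q8_1    / (QK8_1 * sizeof(int));
    const int64_t s12_fp4 = ne11 * ne10_padded * sizeof_block_fp4_mmq / (8 * QK_MXFP4 * sizeof(int));

    // block_fp4_mmq covers 256 values instead of 128, so the stride halves
    printf("s12_q8 = %lld, s12_fp4 = %lld\n", (long long) s12_q8, (long long) s12_fp4);
    return 0;
}
```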
const mmq_args args = {
@@ -175,12 +191,19 @@ void ggml_cuda_mul_mat_q(
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[2] / ts_src1;
- quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type,
- ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
+
+ if (use_native_mxfp4) {
+ quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
+ ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
+ } else {
+ quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
+ ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
+ }
CUDA_CHECK(cudaGetLastError());
}
- const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
+ const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) :
+ ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
const int64_t s13 = ne12*s12;
// Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index fa8a72c9c1..63451ffab7 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -11,6 +11,7 @@ using namespace ggml_cuda_mma;
#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
#define MMQ_ITER_K 256
+#define MMQ_ITER_K_MXFP4_FP4 512
#define MMQ_NWARPS 8
typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
@@ -44,8 +45,15 @@ struct block_q8_1_mmq {
};
int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
};
+
+struct block_fp4_mmq {
+ uint32_t d4[4]; // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc.
+ int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values
+};
+
static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
+static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected block_fp4_mmq size");
static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
switch (type_x) {
@@ -129,6 +137,14 @@ static int get_mmq_y_host(const int cc) {
((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64);
}
+static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) {
+#if defined(BLACKWELL_MMA_AVAILABLE)
+ return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K;
+#else
+ return MMQ_ITER_K;
+#endif // defined(BLACKWELL_MMA_AVAILABLE)
+}
+
static constexpr __device__ int get_mmq_y_device() {
#if defined(GGML_USE_HIP)
#if defined(RDNA1)
@@ -191,6 +207,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
}
#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
+#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4)
#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
@@ -201,6 +218,8 @@ static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4");
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
switch (type) {
@@ -209,6 +228,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
+ // tile sizes are the same for Q8_1 and FP4 on Blackwell
case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
@@ -228,7 +248,8 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
}
// block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales)
-#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1)
+#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K / QI8_1)
+#define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K
static int mmq_get_granularity_host(const int mmq_x, const int cc) {
if (amd_mfma_available(cc) || amd_wmma_available(cc)) {
@@ -761,6 +782,50 @@ template static __device__ __forceinline__ void loa
}
}
+template <int mmq_y, bool need_check>
+static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restrict__ x,
+ int * __restrict__ x_tile,
+ const int kbx0,
+ const int i_max,
+ const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+ int * x_qs = (int *) x_tile;
+ uint32_t * x_sc = (uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
+
+ const int txi = threadIdx.x;
+
+ constexpr int iter_k = get_iter_k(GGML_TYPE_MXFP4);
+
+ constexpr int threads_per_row = iter_k / QK_MXFP4; // each thread processes 1 block
+ constexpr int rows_per_warp = warp_size / threads_per_row;
+ const int kbx = txi % threads_per_row;
+ const int row_in_warp = txi / threads_per_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
+ int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
+
+ if constexpr (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx;
+
+ // quantize_mxfp4_mmq permutes nibbles to match the quantized format
+ const int k0 = kbx * 4;
+ memcpy(x_qs + i * MMQ_MMA_TILE_X_K_FP4 + k0, bxi->qs, 16);
+
+ // Load E8M0 scales: pack 2 consecutive scales into one uint32
+ if (kbx % 2 == 0) {
+ uint32_t e = bxi->e;
+ e |= ((bxi + 1)->e << 8);
+ x_sc[i * MMQ_MMA_TILE_X_K_FP4 + kbx / 2] = e;
+ }
+ }
+}
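
For reference (not part of the patch): the loader packs two consecutive E8M0 scale bytes into the low 16 bits of a uint32, matching the d4 layout written by quantize_mmq_mxfp4. A minimal pack/unpack sketch of that convention:

```cpp
#include <cstdint>
#include <cstdio>

// Two consecutive E8M0 scales are packed into the low 16 bits of a uint32:
// bits [7:0] hold the even block's scale, bits [15:8] the odd block's scale.
static uint32_t pack_scales(uint8_t e_even, uint8_t e_odd) {
    return (uint32_t) e_even | ((uint32_t) e_odd << 8);
}

int main() {
    const uint32_t packed = pack_scales(130, 125);
    const uint8_t  s0 = packed & 0xFF;
    const uint8_t  s1 = (packed >> 8) & 0xFF;
    printf("s0 = %u, s1 = %u\n", s0, s1); // 130, 125
    return 0;
}
```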
+
template
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@@ -931,6 +996,78 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
+template
+static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x,
+ const int * __restrict__ y,
+ float * __restrict__ sum,
+ const int k00) {
+ typedef tile<16, 8, int> tile_A;
+ typedef tile<8, 8, int> tile_B;
+ typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = 2 * granularity;
+ constexpr int ntx = rows_per_warp / tile_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K);
+
+ // Match layout from load_tiles_mxfp4_fp4
+ const int * x_qs = (const int *) x;
+ const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
+ const int * y_qs = (const int *) y + 4;
+ const uint32_t * y_sc = (const uint32_t *) y;
+
+ // tile_A has a length of 64 logical values vs. 32 values in block_mxfp4
+ tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
+ uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
+
+ // Block scale
+ // Each thread has to supply a 4-byte scale value
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
+
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
+ const int k0 = k00 + k01;
+
+ load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0,
+ MMQ_MMA_TILE_X_K_FP4);
+
+ // per the block-scaling documentation, 2 threads in each quad need to supply the scale value
+ const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
+ scaleA[n][k01 / (2 * QI_MXFP4)] =
+ *(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4));
+ }
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
+#pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
+ tile_B B;
+ uint32_t scaleB; // 2xN scales
+
+ load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K);
+
+ scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)];
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ tile_C C;
+
+ mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB);
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
+ }
+ }
+ }
+ }
+}
+
template
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@@ -3109,8 +3246,13 @@ struct mmq_type_traits {
template
struct mmq_type_traits {
static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
+#ifdef BLACKWELL_MMA_AVAILABLE
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_mxfp4_mxfp4_mma;
+#else
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma;
+#endif // BLACKWELL_MMA_AVAILABLE
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
};
@@ -3243,17 +3385,26 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
constexpr mmq_write_back_t write_back = mmq_write_back_dp4a;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
- constexpr int blocks_per_iter = MMQ_ITER_K / qk;
+#if defined(BLACKWELL_MMA_AVAILABLE)
+ // FP4 tile stores 8 blocks
+ constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1;
+#else
+ constexpr int ne_block = 4 * QK8_1;
+#endif // defined(BLACKWELL_MMA_AVAILABLE)
+
+ constexpr int ITER_K = get_iter_k(type);
+ constexpr int blocks_per_iter = ITER_K / qk;
float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
+ constexpr int sz = sizeof(block_q8_1_mmq) / sizeof(int);
+
for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
-
{
- const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
+ const int * by0 = y + ncols_y * (kb0 * qk / ne_block) * sz;
#pragma unroll
- for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
+ for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
int l = l0 + threadIdx.y*warp_size + threadIdx.x;
tile_y[l] = by0[l];
@@ -3267,9 +3418,9 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
__syncthreads();
{
- const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
+ const int * by0 = y + ncols_y * ((kb0 * qk / ne_block) * sz + sz);
#pragma unroll
- for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
+ for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
int l = l0 + threadIdx.y*warp_size + threadIdx.x;
tile_y[l] = by0[l];
@@ -3401,8 +3552,10 @@ static __global__ void mul_mat_q(
}
#endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
+ constexpr int ITER_K = get_iter_k(type);
+
const int64_t blocks_per_ne00 = ncols_x / qk;
- constexpr int blocks_per_iter = MMQ_ITER_K / qk;
+ constexpr int blocks_per_iter = ITER_K / qk;
// kbc == k block continuous, current index in continuous ijk space.
int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
@@ -3463,7 +3616,7 @@ static __global__ void mul_mat_q(
__syncthreads();
}
- offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
+ offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
offset_dst += it*mmq_y;
const int tile_x_max_i = nrows_x - it*mmq_y - 1;
@@ -3530,7 +3683,7 @@ static __global__ void mul_mat_q(
__syncthreads();
}
- offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
+ offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
offset_dst += it*mmq_y;
const int tile_x_max_i = nrows_x - it*mmq_y - 1;
@@ -3553,7 +3706,9 @@ static __global__ void mul_mat_q_stream_k_fixup(
const int ncols_max) {
constexpr int mmq_y = get_mmq_y_device();
constexpr int qk = ggml_cuda_type_traits::qk;
- constexpr int blocks_per_iter = MMQ_ITER_K / qk;
+ constexpr int ITER_K = get_iter_k(type);
+
+ constexpr int blocks_per_iter = ITER_K / qk;
const int64_t blocks_per_ne00 = ncols_x / qk;
constexpr int nwarps = mmq_get_nwarps_device();
@@ -3711,7 +3866,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
const size_t nbs_ids = mmq_x*sizeof(int);
const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
- const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
+ const size_t nbs_y = mmq_x * (sizeof(block_q8_1_mmq));
return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
}
diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu
index 5117f9ffc0..a8c68e44b1 100644
--- a/ggml/src/ggml-cuda/quantize.cu
+++ b/ggml/src/ggml-cuda/quantize.cu
@@ -47,6 +47,131 @@ static __global__ void quantize_q8_1(
y[ib].ds = make_half2(d, sum);
}
+__device__ __forceinline__ uint8_t compute_e8m0_scale(float amax) {
+ if (!(amax > 0.0f)) {
+ return 0;
+ }
+
+ // FP4 E2M1: max exponent (unbiased) is 2.
+ constexpr int FP4_E2M1_EMAX = 2;
+
+ const float e = log2f(amax);
+
+ // "even" -> round-to-nearest integer, ties-to-even
+ const int e_int = __float2int_rn(e);
+
+ const int shared_exp = e_int - FP4_E2M1_EMAX;
+
+ int biased = shared_exp + 127;
+
+ biased = max(biased, 0);
+ biased = min(biased, 254);
+
+ return static_cast<uint8_t>(biased);
+}
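
Worked example (not part of the patch): for a block absolute maximum of 6.0, log2(6) ≈ 2.585 rounds to 3, the E2M1 max exponent of 2 is subtracted and the bias 127 added, so the stored exponent is 128 and the decoded scale is 2.0. A host-side mirror of the same rounding:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

// Host-side mirror of compute_e8m0_scale(): round log2(amax) to the nearest
// integer (ties to even, like __float2int_rn), subtract the FP4 E2M1 max
// exponent (2), bias by 127, and clamp to the valid E8M0 range.
static uint8_t e8m0_scale(float amax) {
    if (!(amax > 0.0f)) {
        return 0;
    }
    const int e_int  = (int) std::lrintf(std::log2f(amax));
    int       biased = e_int - 2 + 127;
    biased = biased < 0 ? 0 : (biased > 254 ? 254 : biased);
    return (uint8_t) biased;
}

int main() {
    printf("%u\n", e8m0_scale(6.0f)); // 128 -> scale 2^(128-127) = 2.0
    printf("%u\n", e8m0_scale(1.0f)); // 125 -> scale 2^(125-127) = 0.25
    return 0;
}
```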
+
+// quantize values into the interleaved-nibble layout in which mxfp4 is stored,
+// i.e. a block a0..a31 is represented as a0a16, a1a17, ..., a15a31
+static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
+ const int32_t * __restrict__ ids,
+ void * __restrict__ vy,
+ const int64_t ne00,
+ const int64_t s01,
+ const int64_t s02,
+ const int64_t s03,
+ const int64_t ne0,
+ const int ne1,
+ const int ne2) {
+ constexpr int vals_per_scale = 32;
+ constexpr int vals_per_warp = 2 * vals_per_scale; // Each warp processes 2 blocks of 32 = 64 values
+
+ const int warp_id = threadIdx.y;
+ const int lane_id_32 = threadIdx.x;
+
+ const int nwarps = blockDim.y;
+
+ const int64_t warp_start_offset = (blockIdx.y * nwarps + warp_id) * vals_per_warp;
+
+ if (warp_start_offset >= ne0) {
+ return;
+ }
+
+ const int64_t i1 = blockIdx.x;
+ const int64_t i2 = blockIdx.z % ne2;
+ const int64_t i3 = blockIdx.z / ne2;
+
+ const int64_t i01 = ids ? ids[i1] : i1;
+ const int64_t i02 = i2;
+ const int64_t i03 = i3;
+
+ block_fp4_mmq * y = (block_fp4_mmq *) vy;
+
+ const int64_t block_fp4_mmq_size = 8 * QK_MXFP4; // 256 values
+ const int64_t ib0 = blockIdx.z * ((int64_t) ne1 * (ne0 / block_fp4_mmq_size));
+ const int64_t ib = ib0 + (warp_start_offset / block_fp4_mmq_size) * ne1 + blockIdx.x;
+ const int64_t quad_idx_in_block = (warp_start_offset % block_fp4_mmq_size) / vals_per_warp;
+
+ const int group_id = lane_id_32 / 4;
+ const int lane_in_group = lane_id_32 % 4;
+ const int base = group_id * 2;
+ char2 * yqs2 = (char2 *) y[ib].qs;
+
+ const int64_t base_pos = i03 * s03 + i02 * s02 + i01 * s01;
+
+ uint8_t scales[2];
+
+#pragma unroll
+ for (int b = 0; b < 2; ++b) {
+ const int64_t i0 = warp_start_offset + b * vals_per_scale + lane_id_32;
+ const float xi = (i0 < ne00) ? x[base_pos + i0] : 0.0f;
+
+ float amax = fabsf(xi);
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
+ }
+
+ const uint8_t e = compute_e8m0_scale(amax);
+ scales[b] = e;
+ const float inv_s = (amax == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e));
+
+#if CUDART_VERSION >= 12080
+ const float scaled_val = xi * inv_s;
+
+ const float val0 = __shfl_sync(0xFFFFFFFF, scaled_val, base, WARP_SIZE);
+ const float val1 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 16, WARP_SIZE);
+ const float val2 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 1, WARP_SIZE);
+ const float val3 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 17, WARP_SIZE);
+
+ if (lane_in_group == 0) {
+ __nv_fp4x4_e2m1 fp4_packed(make_float4(val0, val1, val2, val3));
+
+ yqs2[quad_idx_in_block * 16 + b * 8 + group_id] = *(char2 *) &fp4_packed;
+ }
+#else
+ // Fallback: manual FP4 conversion using LUT
+ const uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s);
+
+ const uint8_t q_lo_0 = __shfl_sync(0xFFFFFFFF, q_val, base, WARP_SIZE);
+ const uint8_t q_lo_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 1, WARP_SIZE);
+ const uint8_t q_hi_0 = __shfl_sync(0xFFFFFFFF, q_val, base + 16, WARP_SIZE);
+ const uint8_t q_hi_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 17, WARP_SIZE);
+
+ if (lane_in_group == 0) {
+ char2 q;
+ q.x = (q_hi_0 << 4) | q_lo_0;
+ q.y = (q_hi_1 << 4) | q_lo_1;
+ yqs2[quad_idx_in_block * 16 + b * 8 + group_id] = q;
+ }
+#endif // CUDART_VERSION >= 12080
+ }
+
+ if (lane_id_32 == 0) {
+ // Store 2 scales packed into 1 uint32
+ y[ib].d4[quad_idx_in_block] = (scales[1] << 8) | scales[0];
+ }
+}
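
For reference (not part of the patch): the interleaved-nibble layout stores a block of 32 4-bit codes a0..a31 as 16 bytes, with a[i] in the low nibble and a[i+16] in the high nibble of byte i. A host model of that packing:

```cpp
#include <cstdint>
#include <cstdio>

// Host model of the interleaved-nibble packing used by quantize_mmq_mxfp4:
// byte i holds code a[i] in its low nibble and a[i+16] in its high nibble.
static void pack_block(const uint8_t * codes /* 32 values, each 0..15 */,
                       uint8_t * out         /* 16 bytes */) {
    for (int i = 0; i < 16; ++i) {
        out[i] = (uint8_t) ((codes[i + 16] << 4) | (codes[i] & 0x0F));
    }
}

int main() {
    uint8_t codes[32];
    for (int i = 0; i < 32; ++i) {
        codes[i] = (uint8_t) (i % 16);
    }
    uint8_t packed[16];
    pack_block(codes, packed);
    printf("byte 0 = 0x%02X\n", packed[0]); // a16 in high nibble, a0 in low -> 0x00
    printf("byte 1 = 0x%02X\n", packed[1]); // a17 in high nibble, a1 in low -> 0x11
    return 0;
}
```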
+
template
static __global__ void quantize_mmq_q8_1(
const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy,
@@ -190,3 +315,29 @@ void quantize_mmq_q8_1_cuda(
break;
}
}
+
+void quantize_mmq_mxfp4_cuda(const float * x,
+ const int32_t * ids,
+ void * vy,
+ [[maybe_unused]] const ggml_type type_src0,
+ const int64_t ne00,
+ const int64_t s01,
+ const int64_t s02,
+ const int64_t s03,
+ const int64_t ne0,
+ const int64_t ne1,
+ const int64_t ne2,
+ const int64_t ne3,
+ cudaStream_t stream) {
+ GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0);
+
+ constexpr int nwarps = 8;
+ constexpr int vals_per_warp = 2 * QK_MXFP4;
+ constexpr int vals_per_block = nwarps * vals_per_warp;
+
+ const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block;
+ const dim3 num_blocks(ne1, block_num_y, ne2 * ne3);
+ const dim3 block_size(WARP_SIZE, nwarps, 1);
+
+ quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
+}
diff --git a/ggml/src/ggml-cuda/quantize.cuh b/ggml/src/ggml-cuda/quantize.cuh
index 725ab52443..6a91df6357 100644
--- a/ggml/src/ggml-cuda/quantize.cuh
+++ b/ggml/src/ggml-cuda/quantize.cuh
@@ -25,3 +25,17 @@ void quantize_mmq_q8_1_cuda(
const float * x, const int32_t * ids, void * vy,
ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
+
+void quantize_mmq_mxfp4_cuda(const float * x,
+ const int32_t * ids,
+ void * vy,
+ ggml_type type_src0,
+ int64_t ne00,
+ int64_t s01,
+ int64_t s02,
+ int64_t s03,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2,
+ int64_t ne3,
+ cudaStream_t stream);
diff --git a/ggml/src/ggml-cuda/vendors/cuda.h b/ggml/src/ggml-cuda/vendors/cuda.h
index 3b3086778e..ba032cfab4 100644
--- a/ggml/src/ggml-cuda/vendors/cuda.h
+++ b/ggml/src/ggml-cuda/vendors/cuda.h
@@ -10,6 +10,10 @@
#include
#endif // CUDART_VERSION >= 12050
+#if CUDART_VERSION >= 12080
+#include <cuda_fp4.h>
+#endif // CUDART_VERSION >= 12080
+
#if CUDART_VERSION < 11020
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index fe57d4c582..80e0fd2ff8 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -24,10 +24,6 @@
#include
#endif
-#if defined(__F16C__)
-#include <immintrin.h>
-#endif
-
#ifdef __cplusplus
extern "C" {
#endif
diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
index 639715537b..353f6a4b46 100644
--- a/ggml/src/ggml-opencl/ggml-opencl.cpp
+++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
@@ -263,6 +263,32 @@ static ggml_cl_compiler_version get_adreno_cl_compiler_version(const char *drive
return { type, major, minor, patch };
}
+// cl buffer wrapper
+struct ggml_cl_buffer {
+ cl_mem buffer;
+ size_t size;
+
+ ggml_cl_buffer()
+ : buffer(nullptr), size(0) {}
+
+ ~ggml_cl_buffer() {
+ if (buffer) {
+ CL_CHECK(clReleaseMemObject(buffer));
+ }
+ }
+
+ void allocate(cl_context context, size_t new_size) {
+ if (new_size > size) {
+ size = new_size;
+ if (buffer) {
+ CL_CHECK(clReleaseMemObject(buffer));
+ }
+ cl_int err;
+ CL_CHECK((buffer = clCreateBuffer(context, CL_MEM_READ_WRITE, size, NULL, &err), err));
+ }
+ }
+};
+
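
For reference (not part of the patch): allocate() is grow-only — it releases and recreates the cl_mem only when the requested size exceeds the current capacity, so repeated calls reuse the largest buffer seen so far. A plain C++ model of that policy (the cl_mem handle is stubbed out):

```cpp
#include <cstddef>
#include <cstdio>

// Plain C++ model of the grow-only allocation policy used by ggml_cl_buffer:
// the backing store is replaced only when a larger size is requested.
struct grow_only_buffer {
    size_t capacity   = 0;
    int    generation = 0; // stands in for the cl_mem handle

    void allocate(size_t new_size) {
        if (new_size > capacity) {
            capacity = new_size;
            ++generation; // a real implementation releases and recreates the cl_mem here
        }
    }
};

int main() {
    grow_only_buffer buf;
    buf.allocate(1024); // allocates, generation 1
    buf.allocate(512);  // fits, no reallocation
    buf.allocate(4096); // grows, generation 2
    printf("capacity = %zu, generation = %d\n", buf.capacity, buf.generation); // 4096, 2
    return 0;
}
```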
// Profiling
struct ProfilingInfo {
std::string op_name;
@@ -376,6 +402,11 @@ struct ggml_backend_opencl_context {
cl_context context;
cl_command_queue queue;
+ // prealloc buffers for transposing weights and activations
+ ggml_cl_buffer prealloc_quant_trans;
+ ggml_cl_buffer prealloc_scales_trans;
+ ggml_cl_buffer prealloc_act_trans;
+
cl_program program_add;
cl_program program_add_id;
cl_program program_clamp;
@@ -638,10 +669,6 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_transpose_16_buf;
cl_kernel kernel_transpose_16_4x1;
- cl_mem A_s_d_max; // max scale buffer size for transpose
- cl_mem A_q_d_max; // max weight buffer size for transpose
- cl_mem B_d_max; // max activation buffer size for transpose
-
// Gemm and Gemv related programs, kernels, etc
cl_program program_CL_gemm;
cl_program program_CL_gemv_general;
@@ -2600,9 +2627,9 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
required_B_d_bytes, max_B_d_bytes);
}
- CL_CHECK((backend_ctx->A_q_d_max = clCreateBuffer(context, 0, max_A_q_d_bytes, NULL, &err), err));
- CL_CHECK((backend_ctx->A_s_d_max = clCreateBuffer(context, 0, max_A_s_d_bytes, NULL, &err), err));
- CL_CHECK((backend_ctx->B_d_max = clCreateBuffer(context, 0, max_B_d_bytes, NULL, &err), err));
+ backend_ctx->prealloc_quant_trans.allocate(context, max_A_q_d_bytes);
+ backend_ctx->prealloc_scales_trans.allocate(context, max_A_s_d_bytes);
+ backend_ctx->prealloc_act_trans.allocate(context, max_B_d_bytes);
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
backend_ctx->disable_fusion = getenv("GGML_OPENCL_DISABLE_FUSION") != nullptr;
@@ -3607,32 +3634,35 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
// use sub_buffer of max buffer size instead
size_t q_size_bytes = K * M / 8 * sizeof(float);
+ backend_ctx->prealloc_quant_trans.allocate(context, q_size_bytes);
+
cl_buffer_region region;
region.origin = 0;
region.size = q_size_bytes;
cl_mem qT_d = clCreateSubBuffer(
- backend_ctx->A_q_d_max,
+ backend_ctx->prealloc_quant_trans.buffer,
0,
CL_BUFFER_CREATE_TYPE_REGION,
&region,
&err);
- // cl_mem qT_d = clCreateBuffer(context, CL_MEM_READ_WRITE, q_size_bytes, NULL, &err);
CL_CHECK(err);
bool K_tile_trans = true;
if ((K / 32) % 4 != 0){
K_tile_trans =false;
}
+
size_t d_size_bytes = M * (K / 32) * 2;
+ backend_ctx->prealloc_scales_trans.allocate(context, d_size_bytes);
+
region.origin = 0;
region.size = d_size_bytes;
cl_mem dT_d = clCreateSubBuffer(
- backend_ctx->A_s_d_max,
+ backend_ctx->prealloc_scales_trans.buffer,
0,
CL_BUFFER_CREATE_TYPE_REGION,
&region,
&err);
- // cl_mem dT_d = clCreateBuffer(context, CL_MEM_READ_WRITE, d_size_bytes, NULL, &err);
CL_CHECK(err);
// <----------------------------------------------------------------------------------> //
@@ -7395,8 +7425,10 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
region.origin = 0;
// Specify the size of the sub-buffer (divide by 2 for FP16)
region.size = K * (N + padding) * sizeof(float)/2;
+ backend_ctx->prealloc_act_trans.allocate(context, region.size);
+
B_d = clCreateSubBuffer(
- backend_ctx->B_d_max,
+ backend_ctx->prealloc_act_trans.buffer,
0,
CL_BUFFER_CREATE_TYPE_REGION,
&region,
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index a524adbe0c..493ee9c9a4 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -379,18 +379,18 @@ enum FaCodePath {
};
struct vk_fa_pipeline_state {
- vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, FaCodePath path, bool aligned, bool f32acc)
- : HSK(HSK), HSV(HSV), small_rows(small_rows), path(path), aligned(aligned), f32acc(f32acc) {}
+ vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc)
+ : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc) {}
uint32_t HSK, HSV;
- bool small_rows;
+ bool small_rows, small_cache;
FaCodePath path;
bool aligned;
bool f32acc;
bool operator<(const vk_fa_pipeline_state &b) const {
- return std::tie(HSK, HSV, small_rows, path, aligned, f32acc) <
- std::tie(b.HSK, b.HSV, b.small_rows, b.path, b.aligned, b.f32acc);
+ return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc) <
+ std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc);
}
};
@@ -651,7 +651,7 @@ struct vk_device_struct {
vk_pipeline pipeline_add_id_f32;
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
- vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32;
+ vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32, pipeline_upscale_bilinear_antialias_f32;
vk_pipeline pipeline_scale_f32;
vk_pipeline pipeline_sqr_f32;
vk_pipeline pipeline_sqrt_f32;
@@ -763,6 +763,7 @@ struct vk_device_struct {
std::map pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT];
vk_pipeline pipeline_flash_attn_split_k_reduce;
+ vk_pipeline pipeline_count_experts;
// [2] is for whether to take n_experts from spec constant (0) or push constant (1)
vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT][2];
@@ -1004,6 +1005,14 @@ struct vk_op_push_constants {
float param4;
};
+struct vk_op_count_experts_push_constants {
+ uint32_t ne00;
+ uint32_t ne01;
+ uint32_t nb00;
+ uint32_t nb01;
+ uint32_t a_offset;
+};
+
struct vk_op_glu_push_constants {
uint32_t N;
uint32_t ne00;
@@ -1192,6 +1201,7 @@ struct vk_op_diag_mask_push_constants {
struct vk_op_rope_push_constants {
uint32_t rope_mode;
uint32_t ncols;
+ uint32_t nrows;
uint32_t n_dims;
float freq_scale;
uint32_t p_delta_rows;
@@ -1564,7 +1574,7 @@ class vk_perf_logger {
total_op_times += time;
}
std::cerr << t.first << ": " << t.second.size() << " x " << (total_op_times / t.second.size() / 1000.0)
- << " us";
+ << " us = " << (total_op_times / 1000.0) << " us";
// If we have as many flops entries as timing entries for the op, then compute and log the flops/S.
auto it = flops.find(t.first);
@@ -2582,10 +2592,10 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events
static constexpr uint32_t flash_attention_num_small_rows = 32;
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
-static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv) {
+static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) {
if (hsv >= 192) {
return 2;
- } else if ((hsv | hsk) & 8) {
+ } else if ((hsv | hsk) & 8 || small_cache) {
return 4;
} else {
return 8;
@@ -2607,9 +2617,8 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) {
}
}
-static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) {
+static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) {
GGML_UNUSED(clamp);
- GGML_UNUSED(hsv);
if (path == FA_SCALAR) {
if (small_rows) {
@@ -2618,9 +2627,9 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
if ((hsv | hsk) & 8) {
// HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
// larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
- return {get_fa_scalar_num_large_rows(hsk, hsv), 64};
+ return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64};
} else {
- return {get_fa_scalar_num_large_rows(hsk, hsv), 32};
+ return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 32};
}
}
}
@@ -2649,8 +2658,8 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
return {64, 64};
}
-static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows) {
- return fa_rows_cols(path, hsk, hsv, 0, type, small_rows)[1];
+static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows, bool small_cache) {
+ return fa_rows_cols(path, hsk, hsv, 0, type, small_rows, small_cache)[1];
}
static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) {
@@ -2830,9 +2839,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
s_mmq_wg_denoms_k = { 32, 64, 1 };
// spec constants and tile sizes for quant matmul_id
- l_warptile_mmqid = { 256, 128, 128, 16, 1, device->subgroup_size };
- m_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
- s_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
+ l_warptile_mmqid = { 256, 128, 128, 32, 1, device->subgroup_size };
+ m_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size };
+ s_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size };
l_mmqid_wg_denoms = { 128, 128, 1 };
m_mmqid_wg_denoms = { 128, 64, 1 };
s_mmqid_wg_denoms = { 128, 64, 1 };
@@ -2992,11 +3001,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
align, disable_robustness, require_full_subgroups, required_subgroup_size);
};
- auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array {
- return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
+ auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::array {
+ return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1};
};
- auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector {
+ auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::vector {
// For large number of rows, 128 invocations seems to work best.
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
// can't use 256 for D==80.
@@ -3006,7 +3015,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
? scalar_flash_attention_workgroup_size
: ((small_rows && (D % 32) == 0) ? 256 : 128);
- auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows);
+ auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache);
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
@@ -3021,21 +3030,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
uint32_t HSK = fa.first.HSK; \
uint32_t HSV = fa.first.HSV; \
bool small_rows = fa.first.small_rows; \
+ bool small_cache = fa.first.small_cache; \
FaCodePath path = fa.first.path; \
bool aligned = fa.first.aligned; \
bool f32acc = fa.first.f32acc; \
if (path == FAPATH) { \
if (aligned) { \
if (f32acc) { \
- ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} else { \
- ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} \
} else { \
if (f32acc) { \
- ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} else { \
- ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} \
} \
} \
@@ -3067,17 +3077,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
#endif
#undef CREATE_FA
+ const int mul_mat_id_param_count = 5;
+
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (device->coopmat2) {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, true); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, true); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, true); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, true); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, true); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, true); \
// Create 2 variants, {f16,f32} accumulator
#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
@@ -3113,32 +3125,32 @@ static void ggml_vk_load_shaders(vk_device& device) {
GGML_ASSERT(device->subgroup_ballot);
- CREATE_MM2(pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 5)
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
if (device->coopmat_bf16_support) {
- CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 5)
}
#endif
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
#undef CREATE_MM
#undef CREATE_MM2
} else
@@ -3227,35 +3239,35 @@ static void ggml_vk_load_shaders(vk_device& device) {
GGML_ASSERT(device->subgroup_ballot);
- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id);
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
if (device->coopmat_bf16_support) {
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id);
}
#endif
- CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+ CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+ CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+ CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+ CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+ CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+ CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+ CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+ CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+ CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+ CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+ CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+ CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+ CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+ CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+ CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+ CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+ CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+ CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+ CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
#undef CREATE_MM2
#undef CREATE_MM
} else
@@ -3340,91 +3352,91 @@ static void ggml_vk_load_shaders(vk_device& device) {
#endif
if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
- CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
- CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
- CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
- CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
- CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
- CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
- CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
- CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+ CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+ CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+ CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+ CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
}
#endif
} else {
- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
- CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
- CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
- CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
- CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
}
#endif
}
@@ -3501,57 +3513,57 @@ static void ggml_vk_load_shaders(vk_device& device) {
#endif
if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
- CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
- CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
- CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_subgroup_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_subgroup_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_subgroup_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_subgroup_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_subgroup_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_subgroup_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_subgroup_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_subgroup_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_subgroup_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_subgroup_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_subgroup_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_subgroup_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_subgroup_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
- CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_subgroup_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_subgroup_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_subgroup_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_subgroup_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_subgroup_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_subgroup_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_subgroup_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_subgroup_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_subgroup_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_subgroup_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_subgroup_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_subgroup_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_subgroup_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
} else {
- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
- CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
- CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
- CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
- CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
}
}
// reusing CREATE_MM from the fp32 path
@@ -3570,7 +3582,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
s_wg_denoms = { 32, 32, 1 };
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
}
#undef CREATE_MM
@@ -3955,6 +3967,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
ggml_vk_create_pipeline(device, device->pipeline_upscale_bicubic_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BICUBIC}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_antialias_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS}, 1);
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -4126,6 +4139,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_count_experts, "count_experts", count_experts_len, count_experts_data, "main", 2, sizeof(vk_op_count_experts_push_constants), {1, 1, 1}, {}, 1, true);
+
for (auto &s : device->pipeline_solve_tri_f32) {
const vk_solve_tri_pipeline_state &state = s.first;
@@ -6523,18 +6538,18 @@ static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context *
static void ggml_vk_matmul_id(
ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,
- vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
+ vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids, const vk_subbuffer & expert_count_buf,
uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11,
uint32_t padded_n) {
- VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " <<
+ VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), expert_count: (" << expert_count_buf.buffer->buffer << ", " << expert_count_buf.offset << ", " << expert_count_buf.size << "), " <<
"m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " <<
"batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
"n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")");
const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
nei0, nei1, nbi1, ne11, padded_n };
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, pc, { m, nei1, n_as });
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids, expert_count_buf }, pc, { m, nei1, n_as });
}
static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
@@ -7517,6 +7532,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
const uint64_t nei0 = ids->ne[0];
const uint64_t nei1 = ids->ne[1];
+ const uint32_t nbi0 = ids->nb[0];
const uint32_t nbi1 = ids->nb[1];
const uint32_t nbi2 = ids->nb[2];
@@ -7624,6 +7640,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (quantize_y) {
to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
}
+ vk_pipeline count_experts = ctx->device->pipeline_count_experts;
+
+ uint32_t expert_count_size = sizeof(uint32_t) * n_as;
{
if (
@@ -7639,6 +7658,10 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
ctx->prealloc_size_y = y_sz;
ggml_vk_preallocate_buffers(ctx, subctx);
}
+ if (ctx->prealloc_size_split_k < expert_count_size) {
+ ctx->prealloc_size_split_k = expert_count_size;
+ ggml_vk_preallocate_buffers(ctx, subctx);
+ }
// Request descriptor sets
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
@@ -7651,6 +7674,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (quantize_y) {
ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
}
+ ggml_pipeline_request_descriptor_sets(ctx, count_experts, 1);
}
vk_buffer d_D = dst_buf_ctx->dev_buffer;
@@ -7700,6 +7724,20 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
ggml_vk_sync_buffers(ctx, subctx);
}
}
+ // Count how many times each expert is used
+ vk_subbuffer expert_count_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
+ if (ctx->prealloc_split_k_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
+ {
+ const std::vector<uint32_t> pc = { (uint32_t)nei0,
+ (uint32_t)nei1,
+ (uint32_t)(nbi0 / ggml_type_size(ids->type)),
+ (uint32_t)(nbi1 / ggml_type_size(ids->type)),
+ (uint32_t)(get_misalign_bytes(ctx, ids) / ggml_type_size(ids->type)) };
+ ggml_vk_dispatch_pipeline(ctx, subctx, count_experts,
+ { vk_subbuffer{ d_ids, ids_buf_offset, ids_sz }, expert_count_buf }, pc, { (uint32_t)n_as, 1, 1});
+ }
if (x_non_contig) {
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0));
@@ -7707,7 +7745,6 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0,
{ vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_X, 0, x_sz } }, pc, { (uint32_t)x_ne, 1, 1});
- ggml_vk_sync_buffers(ctx, subctx);
}
if (y_non_contig) {
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
@@ -7731,6 +7768,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
ctx->prealloc_y_last_tensor_used = src1;
}
}
+ ggml_vk_sync_buffers(ctx, subctx);
uint32_t stride_batch_x = ne00*ne01;
uint32_t stride_batch_y = ne10*ne11;
@@ -7747,7 +7785,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
ggml_vk_matmul_id(
ctx, subctx, pipeline,
{ d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz },
- { d_D, d_buf_offset, d_sz }, { d_ids, ids_buf_offset, ids_sz },
+ { d_D, d_buf_offset, d_sz }, { d_ids, ids_buf_offset, ids_sz }, expert_count_buf,
ne01, ne21, ne10, ne10, ne10, ne01,
stride_batch_x, stride_batch_y, ne20*ne21,
n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n
@@ -7759,6 +7797,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (y_non_contig || quantize_y) {
ctx->prealloc_y_need_sync = true;
}
+ ctx->prealloc_split_k_need_sync = true;
}
static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
@@ -8008,11 +8047,11 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
}
}
-static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) {
+static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool small_cache) {
// Needs to be kept up to date on shader changes
GGML_UNUSED(hsv);
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
- const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv);
+ const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv, small_cache);
const uint32_t Bc = scalar_flash_attention_Bc;
const uint32_t tmpsh = wg_size * sizeof(float);
@@ -8136,6 +8175,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
uint32_t workgroups_y = (uint32_t)neq2;
uint32_t workgroups_z = (uint32_t)neq3;
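+ // KV sequences shorter than 1024 select dedicated "small cache" pipeline variants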
+ const bool small_cache = nek1 < 1024;
+
// For scalar/coopmat1 FA, we can use the "large" size to accommodate gqa.
// For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
uint32_t max_gqa;
@@ -8143,7 +8184,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
case FA_SCALAR:
case FA_COOPMAT1:
// We may switch from coopmat1 to scalar, so use the scalar limit for both
- max_gqa = get_fa_scalar_num_large_rows(HSK, HSV);
+ max_gqa = get_fa_scalar_num_large_rows(HSK, HSV, small_cache);
break;
case FA_COOPMAT2:
max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
@@ -8177,7 +8218,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
// with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
if (path == FA_SCALAR &&
- !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) {
+ !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, small_cache)) {
small_rows = true;
}
@@ -8193,7 +8234,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
v_stride /= 4;
}
- uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows);
+ uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows, small_cache);
bool aligned = (KV % alignment) == 0 &&
// the "aligned" shader variant will forcibly align strides, for performance
(q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
@@ -8205,7 +8246,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
- vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, path, aligned, f32acc);
+ vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc);
vk_pipeline pipeline = nullptr;
@@ -8430,7 +8471,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return nullptr;
case GGML_OP_UPSCALE:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(dst, 0) & 0xFF);
+ uint32_t mode = (ggml_get_op_params_i32(dst, 0) & (0xFF | GGML_SCALE_FLAG_ANTIALIAS));
switch (mode) {
case GGML_SCALE_MODE_NEAREST:
return ctx->device->pipeline_upscale_nearest_f32;
@@ -8438,6 +8479,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_upscale_bilinear_f32;
case GGML_SCALE_MODE_BICUBIC:
return ctx->device->pipeline_upscale_bicubic_f32;
+ case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS:
+ return ctx->device->pipeline_upscale_bilinear_antialias_f32;
default:
return nullptr;
}
@@ -9088,10 +9131,20 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 };
} break;
case GGML_OP_DIAG_MASK_INF:
- case GGML_OP_ROPE:
- case GGML_OP_ROPE_BACK:
elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
break;
+ case GGML_OP_ROPE:
+ case GGML_OP_ROPE_BACK:
+ {
+ uint32_t nrows = (uint32_t)ggml_nrows(src0);
+ uint32_t z = 1;
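+ // rows beyond the X-dimension dispatch limit are folded into the Z dimension, 32768 rows per Z slice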
+ if (nrows > ctx->device->properties.limits.maxComputeWorkGroupCount[0]) {
+ z = CEIL_DIV(nrows, 32768);
+ nrows = 32768;
+ }
+ elements = { nrows, (uint32_t)ne00, z };
+
+ } break;
case GGML_OP_GET_ROWS:
elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) };
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
@@ -10019,7 +10072,7 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *
uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type);
vk_op_rope_push_constants rope {
- (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
+ (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
has_ff, (uint32_t)src0->ne[2], nb01, nb02,
{ sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
@@ -13716,6 +13769,7 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
}
static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
+ VK_LOG_DEBUG("ggml_backend_vk_event_record(backend=" << backend << ", event=" << event << ")");
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
vk_event *vkev = (vk_event *)event->context;
@@ -13745,6 +13799,7 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev
}
static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
+ VK_LOG_DEBUG("ggml_backend_vk_event_wait(backend=" << backend << ", event=" << event << ")");
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
vk_event *vkev = (vk_event *)event->context;
@@ -13760,6 +13815,8 @@ static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_even
}
ggml_vk_wait_events(transfer_ctx, {vkev->event});
+ ggml_vk_ctx_end(transfer_ctx);
+ ctx->transfer_ctx.reset();
}
// TODO: enable async and synchronize
@@ -14324,7 +14381,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
}
return true;
case GGML_OP_UPSCALE:
- return op->src[0]->type == GGML_TYPE_F32 && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
+ if (op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS) {
+ if ((op->op_params[0] & 0xFF) != GGML_SCALE_MODE_BILINEAR) {
+ return false;
+ }
+ }
+ return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ACC:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_CONCAT:
@@ -14519,6 +14581,7 @@ static void ggml_backend_vk_device_event_free(ggml_backend_dev_t dev, ggml_backe
}
static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
+ VK_LOG_DEBUG("ggml_backend_vk_device_event_synchronize(backend=" << dev << ", event=" << event << ")");
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
auto device = ggml_vk_get_device(ctx->device);
vk_event *vkev = (vk_event *)event->context;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp b/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp
new file mode 100644
index 0000000000..ffc8608691
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp
@@ -0,0 +1,51 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+
+#include "types.glsl"
+
+layout (push_constant) uniform parameter
+{
+ uint32_t ne00;
+ uint32_t ne01;
+ uint32_t nb00;
+ uint32_t nb01;
+ uint32_t a_offset;
+} p;
+
+#define BLOCK_SIZE 256
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {uint data_a[];};
+layout (binding = 1) writeonly buffer D {uint data_d[];};
+
+shared uint vals[BLOCK_SIZE];
+
+void main() {
+ const uint expert_id = gl_WorkGroupID.x;
+ const uint num_elements = p.ne00 * p.ne01;
+ const uint tid = gl_LocalInvocationID.x;
+
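+ // each invocation scans a strided slice of the ids tensor and counts matches for this workgroup's expert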
+ uint count = 0;
+ for (uint idx = tid; idx < num_elements; idx += BLOCK_SIZE) {
+ const uint i01 = idx / p.ne00;
+ const uint i00 = idx % p.ne00;
+ const uint a = data_a[p.a_offset + i01 * p.nb01 + i00 * p.nb00];
+
+ count += uint(a == expert_id);
+ }
+
+ vals[tid] = count;
+ barrier();
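+ // tree reduction in shared memory to sum the per-thread partial counts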
+ [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ vals[tid] += vals[tid + s];
+ }
+ barrier();
+ }
+
+ if (tid == 0) {
+ data_d[expert_id] = vals[0];
+ }
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl
index 70ee542d96..376944f1e2 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl
@@ -401,13 +401,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3;
const uint qshift = (iqs & 16) >> 2;
- u8vec4 qs = u8vec4(
- data_a[a_offset + ib].qs[iq + 0],
- data_a[a_offset + ib].qs[iq + 1],
- data_a[a_offset + ib].qs[iq + 2],
- data_a[a_offset + ib].qs[iq + 3]
- );
- qs = (qs >> qshift) & uint8_t(0xF);
+ const u8vec4 qs = unpack8((data_a_packed32[a_offset + ib].qs[iq/4] >> qshift) & 0x0F0F0F0F);
const float dl = float(int(sl | (sh << 4)) - 32);
return dl * vec4(
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
index 5c5251da39..c0c00d28fc 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
@@ -68,6 +68,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};
+layout (binding = 4) readonly buffer Counts {int data_expert_count[];};
#endif
layout (push_constant) uniform parameter
@@ -135,13 +136,19 @@ shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#include "mul_mm_funcs.glsl"
void main() {
+ const uint ic = gl_WorkGroupID.y;
+
+#ifdef MUL_MAT_ID
+ const uint expert_idx = gl_GlobalInvocationID.z;
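+ // skip workgroups whose N-tile lies entirely past the number of rows routed to this expert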
+ if (ic * BN >= data_expert_count[expert_idx]) {
+ return;
+ }
+#endif
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
-#ifdef MUL_MAT_ID
- const uint expert_idx = gl_GlobalInvocationID.z;
-#else
+#ifndef MUL_MAT_ID
const uint batch_idx = gl_GlobalInvocationID.z;
const uint i13 = batch_idx / p.ne12;
@@ -156,7 +163,6 @@ void main() {
const uint blocks_m = (p.M + BM - 1) / BM;
const uint ir = gl_WorkGroupID.x % blocks_m;
const uint ik = gl_WorkGroupID.x / blocks_m;
- const uint ic = gl_WorkGroupID.y;
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
const uint WSUBM = WM / WMITER;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
index 2e04baa44e..d0d1d8ef72 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
@@ -92,6 +92,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};
+layout (binding = 4) readonly buffer Counts {int data_expert_count[];};
shared u16vec4 row_ids[BN];
@@ -107,11 +108,7 @@ B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const i
{
const uint row_i = blockCoords[0];
- if (row_i >= _ne1) {
- return B_TYPE(0.0);
- }
-
- const u16vec4 row_idx = row_ids[row_i & (BN - 1)];
+ const u16vec4 row_idx = row_ids[row_i];
B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]];
return ret;
@@ -138,6 +135,8 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
uint ids[16];
uint iter = 0;
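+ // total rows routed to this expert, so the row-id scan can stop once all of them are found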
+ uint expert_count = data_expert_count[expert_idx];
+
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
// prefetch up to 16 elements
if (iter == 0) {
@@ -185,7 +184,7 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
}
_ne1 += total;
iter &= 15;
- if (_ne1 >= (ic + 1) * BN) {
+ if (_ne1 >= (ic + 1) * BN || _ne1 == expert_count) {
break;
}
}
@@ -194,15 +193,28 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
#endif
void main() {
+ const uint tid = gl_LocalInvocationIndex;
+ const uint ic = gl_WorkGroupID.y;
+
+#ifdef MUL_MAT_ID
+ const uint expert_idx = gl_GlobalInvocationID.z;
+ if (ic * BN >= data_expert_count[expert_idx]) {
+ return;
+ }
+ // initialize to row 0 so we don't need to bounds check
+ if (tid < BN) {
+ row_ids[tid] = u16vec4(0);
+ }
+#if !defined(NEEDS_INIT_IQ_SHMEM)
+ barrier();
+#endif
+#endif
+
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
- const uint tid = gl_LocalInvocationIndex;
-
-#ifdef MUL_MAT_ID
- const uint expert_idx = gl_GlobalInvocationID.z;
-#else
+#ifndef MUL_MAT_ID
const uint batch_idx = gl_GlobalInvocationID.z;
const uint i13 = batch_idx / p.ne12;
@@ -217,7 +229,6 @@ void main() {
const uint blocks_m = (p.M + BM - 1) / BM;
const uint ir = gl_WorkGroupID.x % blocks_m;
const uint ik = gl_WorkGroupID.x / blocks_m;
- const uint ic = gl_WorkGroupID.y;
#ifdef MUL_MAT_ID
if (bitCount(p.nei0) == 1) {
@@ -482,7 +493,7 @@ void main() {
coopmat mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
sum = coopMatMulAdd(mat_a, mat_b, sum);
} else {
@@ -490,7 +501,7 @@ void main() {
coopmat mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
sum = coopMatMulAdd(mat_a, mat_b, sum);
}
@@ -526,7 +537,7 @@ void main() {
coopmat mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
sum = coopMatMulAdd(mat_a, mat_b, sum);
} else {
@@ -534,7 +545,7 @@ void main() {
coopmat mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
sum = coopMatMulAdd(mat_a, mat_b, sum);
}
@@ -571,7 +582,7 @@ void main() {
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
#else
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
#endif
@@ -583,7 +594,7 @@ void main() {
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
#else
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
#endif
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl
index 58ede04400..1a3531761a 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl
@@ -159,14 +159,16 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint is = iqs / 8; // 0..15
const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3
const uint qsshift = halfsplit * 2; // 0,2,4,6
- const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
const int8_t us = int8_t(((data_a[ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF)
| (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4));
const float dl = float(data_a[ib].d) * float(us - 32);
- buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)),
- dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
+ const vec2 qs = vec2(unpack8((uint(data_a_packed16[ib].qs[qsi / 2]) >> qsshift) & 0x0303).xy);
+ const vec2 hm = vec2(unpack8(((uint(data_a_packed16[ib].hmask[hmi / 2]) >> (4 * n + halfsplit)) & 0x0101 ^ 0x0101) << 2).xy);
+
+ buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * (qs.x - hm.x),
+ dl * (qs.y - hm.y));
#elif defined(DATA_A_Q4_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@@ -198,8 +200,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const float d = loadd.x * sc;
const float m = -loadd.y * mbyte;
- buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m),
- fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
+ const vec2 q = vec2(unpack8((uint(data_a_packed16[ib].qs[qsi / 2]) >> (b * 4)) & 0x0F0F).xy);
+
+ buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m),
+ fma(d, q.y, m));
#elif defined(DATA_A_Q5_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@@ -213,8 +217,6 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
const uint qhi = (iqs % 16) * 2; // 0,2,4..30
- const uint8_t hm = uint8_t(1 << (iqs / 16));
-
const vec2 loadd = vec2(data_a[ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
@@ -234,8 +236,12 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const float d = loadd.x * sc;
const float m = -loadd.y * mbyte;
- buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m),
- fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
+ const uint qs = (uint(data_a_packed16[ib].qs[qsi / 2]) >> (b * 4)) & 0x0F0F;
+ const uint qh = ((uint(data_a_packed16[ib].qh[qhi / 2]) >> (iqs / 16)) & 0x0101) << 4;
+ const vec2 q = vec2(unpack8(qs | qh).xy);
+
+ buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m),
+ fma(d, q.y, m));
#elif defined(DATA_A_Q6_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@@ -394,11 +400,9 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const float d = float(data_a[ib].d);
const uint qs = data_a[ib].qs[iqs];
- const uint signs = pack32(u8vec4(
- data_a[ib].qs[is+0],
- data_a[ib].qs[is+1],
- data_a[ib].qs[is+2],
- data_a[ib].qs[is+3]
+ const uint signs = pack32(u16vec2(
+ data_a_packed16[ib].qs[is/2],
+ data_a_packed16[ib].qs[is/2+1]
));
const float db = d * 0.5 * (0.5 + (signs >> 28));
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
@@ -443,8 +447,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3;
const uint qshift = (idx & 8) >> 1;
- u8vec2 qs = u8vec2(data_a[ib].qs[iq], data_a[ib].qs[iq + 1]);
- qs = (qs >> qshift) & uint8_t(0xF);
+ u8vec2 qs = unpack8((uint(data_a_packed16[ib].qs[iq/2]) >> qshift) & 0x0F0F).xy;
const float d = float(data_a[ib].d);
const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl
index 1d0e84ac94..743004ff8a 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl
@@ -13,6 +13,8 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
uint ids[16];
uint iter = 0;
+ uint expert_count = data_expert_count[expert_idx];
+
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
// prefetch up to 16 elements
if (iter == 0) {
@@ -60,7 +62,7 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
}
_ne1 += total;
iter &= 15;
- if (_ne1 >= (ic + 1) * BN) {
+ if (_ne1 >= (ic + 1) * BN || _ne1 == expert_count) {
break;
}
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
index dc8b3df47b..cd36e270ab 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
@@ -35,6 +35,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};
+layout (binding = 4) readonly buffer Counts {int data_expert_count[];};
#endif
layout (push_constant) uniform parameter
@@ -104,13 +105,19 @@ block_b_cache cache_b;
#include "mul_mmq_funcs.glsl"
void main() {
+ const uint ic = gl_WorkGroupID.y;
+
+#ifdef MUL_MAT_ID
+ const uint expert_idx = gl_GlobalInvocationID.z;
+ if (ic * BN >= data_expert_count[expert_idx]) {
+ return;
+ }
+#endif
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
-#ifdef MUL_MAT_ID
- const uint expert_idx = gl_GlobalInvocationID.z;
-#else
+#ifndef MUL_MAT_ID
const uint batch_idx = gl_GlobalInvocationID.z;
const uint i13 = batch_idx / p.ne12;
@@ -125,7 +132,6 @@ void main() {
const uint blocks_m = (p.M + BM - 1) / BM;
const uint ir = gl_WorkGroupID.x % blocks_m;
const uint ik = gl_WorkGroupID.x / blocks_m;
- const uint ic = gl_WorkGroupID.y;
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
const uint WSUBM = WM / WMITER;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp
index 7c1fb1cd22..f7587468a8 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp
@@ -6,6 +6,9 @@
void main() {
const uint i0 = 2*gl_GlobalInvocationID.y;
// i1 is actually i2*nb2+i1, but the rows are contiguous
- const uint i1 = gl_GlobalInvocationID.x;
+ const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+ if (i1 >= pc.nrows) {
+ return;
+ }
rope_multi(i0, i1, pc);
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp
index 68f00c180b..acb8ed7815 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp
@@ -6,6 +6,9 @@
void main() {
const uint i0 = 2*gl_GlobalInvocationID.y;
// i1 is actually i2*nb2+i1, but the rows are contiguous
- const uint i1 = gl_GlobalInvocationID.x;
+ const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+ if (i1 >= pc.nrows) {
+ return;
+ }
rope_neox(i0, i1, pc);
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp
index 28a939ec6a..0033cdb224 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp
@@ -6,6 +6,9 @@
void main() {
const uint i0 = 2*gl_GlobalInvocationID.y;
// i1 is actually i2*nb2+i1, but the rows are contiguous
- const uint i1 = gl_GlobalInvocationID.x;
+ const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+ if (i1 >= pc.nrows) {
+ return;
+ }
rope_norm(i0, i1, pc);
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl
index 82f39cee34..939cf3c51c 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl
@@ -6,6 +6,7 @@
struct rope_params {
uint rope_mode;
uint ncols;
+ uint nrows;
uint n_dims;
float freq_scale;
uint p_delta_rows;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp
index ea1e0fdb41..d93800b5e7 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp
@@ -6,6 +6,9 @@
void main() {
const uint i0 = 2*gl_GlobalInvocationID.y;
// i1 is actually i2*nb2+i1, but the rows are contiguous
- const uint i1 = gl_GlobalInvocationID.x;
+ const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+ if (i1 >= pc.nrows) {
+ return;
+ }
rope_vision(i0, i1, pc);
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl
index 02578c77c4..402a2a8397 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl
@@ -172,16 +172,12 @@ struct block_q8_0
float16_t d;
int8_t qs[32];
};
+
struct block_q8_0_packed16
{
float16_t d;
int16_t qs[32/2];
};
-struct block_q8_0_packed32
-{
- float16_t d;
- int32_t qs[32/4];
-};
#if defined(DATA_A_Q8_0)
#define QUANT_K QUANT_K_Q8_0
@@ -189,7 +185,6 @@ struct block_q8_0_packed32
#define QUANT_AUXF 1
#define A_TYPE block_q8_0
#define A_TYPE_PACKED16 block_q8_0_packed16
-#define A_TYPE_PACKED32 block_q8_0_packed32
#define DATA_A_QUANT_LEGACY
#endif
@@ -201,11 +196,13 @@ struct block_q8_1
f16vec2 ds;
int8_t qs[32];
};
+
struct block_q8_1_packed16
{
f16vec2 ds;
int16_t qs[16];
};
+
struct block_q8_1_packed32
{
f16vec2 ds;
@@ -218,6 +215,7 @@ struct block_q8_1_x4
f16vec2 ds[4];
int32_t qs[32];
};
+
struct block_q8_1_x4_packed128
{
f16vec2 ds[4];
@@ -1346,10 +1344,28 @@ struct block_iq4_xs
uint8_t qs[QUANT_K_IQ4_XS/2];
};
+struct block_iq4_xs_packed16
+{
+ float16_t d;
+ uint16_t scales_h;
+ uint16_t scales_l[QUANT_K_IQ4_XS/128];
+ uint16_t qs[QUANT_K_IQ4_XS/4];
+};
+
+struct block_iq4_xs_packed32
+{
+ float16_t d;
+ uint16_t scales_h;
+ uint32_t scales_l;
+ uint32_t qs[QUANT_K_IQ4_XS/8];
+};
+
#if defined(DATA_A_IQ4_XS)
#define QUANT_K QUANT_K_IQ4_XS
#define QUANT_R QUANT_R_IQ4_XS
#define A_TYPE block_iq4_xs
+#define A_TYPE_PACKED16 block_iq4_xs_packed16
+#define A_TYPE_PACKED32 block_iq4_xs_packed32
#endif
#define QUANT_K_IQ4_NL 32
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp b/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp
index 037ab0c78f..f7d12a8dda 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp
@@ -21,6 +21,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
#define NEAREST 0
#define BILINEAR 1
#define BICUBIC 2
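+// GGML_SCALE_MODE_BILINEAR combined with GGML_SCALE_FLAG_ANTIALIAS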
+#define BILINEAR_ANTIALIAS 513
layout (constant_id = 0) const uint scale_mode = 0;
@@ -62,6 +63,56 @@ float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) {
return fetch_bilinear(c0, c1, d, i12, i13);
}
+float triangle_filter(float x) {
+ return max(1.0f - abs(x), 0.0f);
+}
+
+float interpolate_bilinear_antialias(uint i10, uint i11, uint i12, uint i13) {
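+ // when downscaling (scale factor < 1) the filter support widens so every covered source pixel contributes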
+ const float support1 = max(1.0f, 1.0f / p.sf1);
+ const float invscale1 = 1.0f / support1;
+ const float support0 = max(1.0f, 1.0f / p.sf0);
+ const float invscale0 = 1.0f / support0;
+
+ const uint i02 = uint(i12 / p.sf2);
+ const uint i03 = uint(i13 / p.sf3);
+
+ const float y = (float(i11) + p.pixel_offset) / p.sf1;
+ const float x = (float(i10) + p.pixel_offset) / p.sf0;
+
+ // the range of source pixels that contribute
+ const int x_min = max(int(x - support0 + p.pixel_offset), 0);
+ const int x_max = min(int(x + support0 + p.pixel_offset), int(p.ne00));
+ const int y_min = max(int(y - support1 + p.pixel_offset), 0);
+ const int y_max = min(int(y + support1 + p.pixel_offset), int(p.ne01));
+
+ // bilinear filter with antialiasing
+ float val = 0.0f;
+ float total_weight = 0.0f;
+
+ for (int sy = y_min; sy < y_max; sy++) {
+ const float weight_y = triangle_filter((sy - y + p.pixel_offset) * invscale1);
+
+ for (int sx = x_min; sx < x_max; sx++) {
+ const float weight_x = triangle_filter((sx - x + p.pixel_offset) * invscale0);
+ const float weight = weight_x * weight_y;
+
+ if (weight <= 0.0f) {
+ continue;
+ }
+
+ const float pixel = data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + sy * p.nb01 + sx * p.nb00];
+ val += pixel * weight;
+ total_weight += weight;
+ }
+ }
+
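+ // normalize by the accumulated weight so the filter preserves overall intensity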
+ if (total_weight > 0.0f) {
+ val /= total_weight;
+ }
+
+ return val;
+}
+
// Bicubic interpolation with alpha = -0.75
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
const vec4 bcoeffs1 = vec4( 1.25, -2.25, 0.0, 1.0);
@@ -118,6 +169,9 @@ void main() {
case BICUBIC:
result = interpolate_bicubic(i10, i11, i12, i13);
break;
+ case BILINEAR_ANTIALIAS:
+ result = interpolate_bilinear_antialias(i10, i11, i12, i13);
+ break;
}
data_d[p.d_offset + idx] = D_TYPE(result);
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
index e237a8e102..4a83378374 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -945,6 +945,8 @@ void process_shaders() {
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+ string_to_spv("count_experts", "count_experts.comp", merge_maps(base_dict, {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}}));
+
for (std::string dim_str : {"", "_3d"}) {
for (bool bda : {false, true}) {
std::string bda_str = bda ? "_bda" : "";
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 41d3bd4faf..27578daaf9 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -449,6 +449,8 @@ class MODEL_ARCH(IntEnum):
RND1 = auto()
PANGU_EMBED = auto()
MISTRAL3 = auto()
+ MIMO2 = auto()
+ LLAMA_EMBED = auto()
class VISION_PROJECTOR_TYPE(IntEnum):
@@ -844,6 +846,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.RND1: "rnd1",
MODEL_ARCH.PANGU_EMBED: "pangu-embedded",
MODEL_ARCH.MISTRAL3: "mistral3",
+ MODEL_ARCH.MIMO2: "mimo2",
+ MODEL_ARCH.LLAMA_EMBED: "llama-embed",
}
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -3196,6 +3200,46 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
+ MODEL_ARCH.MIMO2: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_SINKS,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.FFN_EXP_PROBS_B,
+ ],
+ MODEL_ARCH.LLAMA_EMBED: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ ],
# TODO
}
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 276720fcde..1690d991f2 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -320,6 +320,7 @@ class TensorNameMap:
MODEL_TENSOR.ATTN_SINKS: (
"model.layers.{bid}.self_attn.sinks", # openai-moe
+ "model.layers.{bid}.self_attn.attention_sink_bias", # mimov2
),
MODEL_TENSOR.ATTN_GATE: (
diff --git a/include/llama.h b/include/llama.h
index f862930099..4f0124fdc8 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -286,7 +286,7 @@ extern "C" {
// NULL-terminated list of buffer types to use for tensors that match a pattern
const struct llama_model_tensor_buft_override * tensor_buft_overrides;
- int32_t n_gpu_layers; // number of layers to store in VRAM
+ int32_t n_gpu_layers; // number of layers to store in VRAM, a negative value means all layers
enum llama_split_mode split_mode; // how to split the model across multiple GPUs
// the GPU that is used for the entire model when split_mode is LLAMA_SPLIT_MODE_NONE
@@ -467,10 +467,17 @@ extern "C" {
// Frees all allocated memory
LLAMA_API void llama_free(struct llama_context * ctx);
+ enum llama_params_fit_status {
+ LLAMA_PARAMS_FIT_STATUS_SUCCESS = 0, // found allocations that are projected to fit
+ LLAMA_PARAMS_FIT_STATUS_FAILURE = 1, // could not find allocations that are projected to fit
+ LLAMA_PARAMS_FIT_STATUS_ERROR = 2, // a hard error occurred, e.g. because no model could be found at the specified path
+ };
+
// fits mparams and cparams to free device memory (assumes system memory is unlimited)
- // returns true if the parameters could be successfully modified to fit device memory
- // this function is NOT thread safe because it modifies the global llama logger state
- LLAMA_API bool llama_params_fit(
+ // - returns LLAMA_PARAMS_FIT_STATUS_SUCCESS if the parameters could be successfully modified to fit device memory
+ // - this function is NOT thread safe because it modifies the global llama logger state
+ // - only parameters that have the same value as in llama_default_model_params are modified
+ LLAMA_API enum llama_params_fit_status llama_params_fit(
const char * path_model,
struct llama_model_params * mparams,
struct llama_context_params * cparams,
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 4ca8974916..1e155534bd 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -88,6 +88,7 @@ add_library(llama
models/llama-iswa.cpp
models/llama.cpp
models/mamba.cpp
+ models/mimo2-iswa.cpp
models/minicpm3.cpp
models/minimax-m2.cpp
models/modern-bert.cpp
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 80f44ae1bf..75013d8d33 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -115,6 +115,8 @@ static const std::map LLM_ARCH_NAMES = {
{ LLM_ARCH_RND1, "rnd1" },
{ LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
{ LLM_ARCH_MISTRAL3, "mistral3" },
+ { LLM_ARCH_MIMO2, "mimo2" },
+ { LLM_ARCH_LLAMA_EMBED, "llama-embed" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@@ -500,6 +502,7 @@ static std::set llm_get_tensor_names(llm_arch arch) {
case LLM_ARCH_LLAMA:
case LLM_ARCH_DECI:
case LLM_ARCH_MISTRAL3:
+ case LLM_ARCH_LLAMA_EMBED:
return {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_OUTPUT_NORM,
@@ -2188,6 +2191,27 @@ static std::set llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_VISEXP_FFN_DOWN,
LLM_TENSOR_VISEXP_FFN_UP,
};
+ case LLM_ARCH_MIMO2:
+ return {
+ LLM_TENSOR_TOKEN_EMBD,
+ LLM_TENSOR_OUTPUT_NORM,
+ LLM_TENSOR_OUTPUT,
+ LLM_TENSOR_ATTN_NORM,
+ LLM_TENSOR_ATTN_Q,
+ LLM_TENSOR_ATTN_K,
+ LLM_TENSOR_ATTN_V,
+ LLM_TENSOR_ATTN_SINKS,
+ LLM_TENSOR_ATTN_OUT,
+ LLM_TENSOR_FFN_NORM,
+ LLM_TENSOR_FFN_GATE,
+ LLM_TENSOR_FFN_DOWN,
+ LLM_TENSOR_FFN_UP,
+ LLM_TENSOR_FFN_GATE_INP,
+ LLM_TENSOR_FFN_GATE_EXPS,
+ LLM_TENSOR_FFN_DOWN_EXPS,
+ LLM_TENSOR_FFN_UP_EXPS,
+ LLM_TENSOR_FFN_EXP_PROBS_B,
+ };
case LLM_ARCH_GPTJ:
case LLM_ARCH_UNKNOWN:
return {
diff --git a/src/llama-arch.h b/src/llama-arch.h
index a53bc39d18..27bdedc83c 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -119,6 +119,8 @@ enum llm_arch {
LLM_ARCH_RND1,
LLM_ARCH_PANGU_EMBED,
LLM_ARCH_MISTRAL3,
+ LLM_ARCH_MIMO2,
+ LLM_ARCH_LLAMA_EMBED,
LLM_ARCH_UNKNOWN,
};
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 015ebae71d..1c530fdc91 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -294,8 +294,8 @@ llama_context::llama_context(
// enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
bool pipeline_parallel =
model.n_devices() > 1 &&
- model.params.n_gpu_layers > (int) model.hparams.n_layer &&
- model.params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
+ model.n_gpu_layers() > model.hparams.n_layer &&
+ model.split_mode() == LLAMA_SPLIT_MODE_LAYER &&
cparams.offload_kqv &&
!model.has_tensor_overrides();
@@ -1570,7 +1570,7 @@ llm_graph_cb llama_context::graph_get_cb() const {
// norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
// FIXME: fix in ggml_backend_sched
- const bool full_offload = model.params.n_gpu_layers > (int) model.hparams.n_layer;
+ const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer;
if (ubatch.n_tokens < 32 || full_offload) {
if (il != -1 && strcmp(name, "norm") == 0) {
const auto & dev_layer = model.dev_layer(il);
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index f6e95b5d2a..42def73f06 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -123,10 +123,11 @@ struct llama_hparams {
llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
// the size of the sliding window (0 - no SWA)
uint32_t n_swa = 0;
- // if swa_layers[il] == true, then layer il is SWA
- // if swa_layers[il] == false, then layer il is dense (i.e. non-SWA)
+ // if swa_layers[il] == 1, then layer il is SWA
+ // if swa_layers[il] == 0, then layer il is dense (i.e. non-SWA)
// by default, all layers are dense
- std::array<bool, LLAMA_MAX_LAYERS> swa_layers;
+ // note: using uint32_t type for compatibility reason
+ std::array<uint32_t, LLAMA_MAX_LAYERS> swa_layers;
// for State Space Models
uint32_t ssm_d_conv = 0;
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 0d5bcc64fe..1d6134ec05 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -130,6 +130,7 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_230B_A10B: return "230B.A10B";
case LLM_TYPE_235B_A22B: return "235B.A22B";
case LLM_TYPE_300B_A47B: return "300B.A47B";
+ case LLM_TYPE_310B_A15B: return "310B.A15B";
case LLM_TYPE_355B_A32B: return "355B.A32B";
case LLM_TYPE_E2B: return "E2B";
case LLM_TYPE_E4B: return "E4B";
@@ -606,7 +607,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
- if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON) {
+ if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON || arch == LLM_ARCH_LLAMA_EMBED) {
if (hparams.n_rot != hparams.n_embd_head_k) {
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
}
@@ -630,6 +631,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
// arch-specific KVs
switch (arch) {
case LLM_ARCH_LLAMA:
+ case LLM_ARCH_LLAMA_EMBED:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -2338,6 +2340,22 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
+ case LLM_ARCH_MIMO2:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+ hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+
+ ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
+ ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
+ ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa);
+ ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer);
+
+ switch (hparams.n_layer) {
+ case 48: type = LLM_TYPE_310B_A15B; break;
+ default: type = LLM_TYPE_UNKNOWN;
+ }
+ } break;
default: throw std::runtime_error("unsupported model architecture");
}
@@ -2360,11 +2378,11 @@ void llama_model::load_vocab(llama_model_loader & ml) {
bool llama_model::load_tensors(llama_model_loader & ml) {
const auto & split_mode = params.split_mode;
- const auto & n_gpu_layers = params.n_gpu_layers;
const auto & use_mlock = params.use_mlock;
const auto & tensor_split = params.tensor_split;
- const int n_layer = hparams.n_layer;
+ const int n_layer = hparams.n_layer;
+ const int n_gpu_layers = this->n_gpu_layers();
const bool use_mmap_buffer = true;
@@ -2652,6 +2670,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE:
case LLM_ARCH_MISTRAL3:
+ case LLM_ARCH_LLAMA_EMBED:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -6646,6 +6665,44 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { hparams.n_ff_shexp, n_embd }, 0);
}
} break;
+ case LLM_ARCH_MIMO2:
+ {
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+ // output
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+ for (int i = 0; i < n_layer; ++i) {
+ auto & layer = layers[i];
+ uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i);
+ uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i);
+ uint32_t n_head = hparams.n_head(i);
+
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0);
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_v * n_head, n_embd }, 0);
+
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+ layer.attn_sinks = create_tensor(tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, TENSOR_NOT_REQUIRED);
+
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+ // non-MoE branch
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED);
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
+
+ // MoE branch
+ int64_t n_ff_exp = hparams.n_ff_exp;
+ layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED);
+ layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED);
+ layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED);
+ layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED);
+ layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
+ }
+ } break;
default:
throw std::runtime_error("unknown architecture");
}
@@ -6827,6 +6884,14 @@ size_t llama_model::n_devices() const {
return devices.size();
}
+uint32_t llama_model::n_gpu_layers() const {
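+ // a negative value means "offload all layers"; n_layer + 1 also covers the non-repeating output layer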
+ return params.n_gpu_layers >= 0 ? params.n_gpu_layers : hparams.n_layer + 1;
+}
+
+llama_split_mode llama_model::split_mode() const {
+ return params.split_mode;
+}
+
std::map llama_model::memory_breakdown() const {
std::map ret;
for (const auto & [ctx, bufs] : pimpl->ctxs_bufs) {
@@ -7269,16 +7334,20 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
switch (arch) {
case LLM_ARCH_LLAMA:
{
- llm = std::make_unique<llm_build_llama>(*this, params);
+ llm = std::make_unique>(*this, params);
} break;
case LLM_ARCH_LLAMA4:
{
if (hparams.swa_type == LLAMA_SWA_TYPE_NONE) {
- llm = std::make_unique<llm_build_llama>(*this, params);
+ llm = std::make_unique>(*this, params);
} else {
llm = std::make_unique<llm_build_llama_iswa>(*this, params);
}
} break;
+ case LLM_ARCH_LLAMA_EMBED:
+ {
+ llm = std::make_unique>(*this, params);
+ } break;
case LLM_ARCH_DECI:
{
llm = std::make_unique<llm_build_deci>(*this, params);
@@ -7704,6 +7773,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique(*this, params);
} break;
+ case LLM_ARCH_MIMO2:
+ {
+ llm = std::make_unique<llm_build_mimo2_iswa>(*this, params);
+ } break;
default:
GGML_ABORT("fatal error");
}
@@ -7729,7 +7802,7 @@ llama_model_params llama_model_default_params() {
llama_model_params result = {
/*.devices =*/ nullptr,
/*.tensor_buft_overrides =*/ nullptr,
- /*.n_gpu_layers =*/ 999,
+ /*.n_gpu_layers =*/ -1,
/*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER,
/*.main_gpu =*/ 0,
/*.tensor_split =*/ nullptr,
@@ -7874,6 +7947,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_ERNIE4_5:
case LLM_ARCH_ERNIE4_5_MOE:
case LLM_ARCH_MISTRAL3:
+ case LLM_ARCH_LLAMA_EMBED:
return LLAMA_ROPE_TYPE_NORM;
// the pairs of head values are offset by n_rot/2
@@ -7933,6 +8007,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_PANGU_EMBED:
case LLM_ARCH_AFMOE:
case LLM_ARCH_QWEN3NEXT:
+ case LLM_ARCH_MIMO2:
return LLAMA_ROPE_TYPE_NEOX;
case LLM_ARCH_QWEN2VL:
diff --git a/src/llama-model.h b/src/llama-model.h
index 7f560d462f..dbe5edc153 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -123,6 +123,7 @@ enum llm_type {
LLM_TYPE_230B_A10B, // Minimax M2
LLM_TYPE_235B_A22B,
LLM_TYPE_300B_A47B, // Ernie MoE big
+ LLM_TYPE_310B_A15B, // /MiMo-V2-Flash
LLM_TYPE_355B_A32B, // GLM-4.5
LLM_TYPE_E2B,
LLM_TYPE_E4B,
@@ -465,8 +466,6 @@ struct llama_model {
struct ggml_tensor * dense_2_out_layers = nullptr;
struct ggml_tensor * dense_3_out_layers = nullptr;
- llama_model_params params;
-
// gguf metadata
std::unordered_map<std::string, std::string> gguf_kv;
@@ -497,6 +496,9 @@ struct llama_model {
size_t n_tensors() const;
size_t n_devices() const;
+ uint32_t n_gpu_layers() const;
+ llama_split_mode split_mode() const;
+
std::map memory_breakdown() const;
// total number of parameters in the model
@@ -525,6 +527,8 @@ struct llama_model {
ggml_cgraph * build_graph(const llm_graph_params & params) const;
private:
+ llama_model_params params;
+
struct impl;
std::unique_ptr<impl> pimpl;
};
diff --git a/src/llama.cpp b/src/llama.cpp
index 1e18637e36..76b3acbadb 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -140,6 +140,10 @@ enum layer_fraction_t {
};
// this enum is only used in llama_params_fit_impl but needs to be defined outside of it to fix a Windows compilation issue
+class llama_params_fit_exception : public std::runtime_error {
+ using std::runtime_error::runtime_error;
+};
+
static void llama_params_fit_impl(
const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams,
float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides,
@@ -181,12 +185,11 @@ static void llama_params_fit_impl(
}
}
- int64_t sum_total = 0;
+ int64_t sum_free = 0;
int64_t sum_projected_free = 0;
int64_t min_projected_free = INT64_MAX;
int64_t sum_projected_used = 0;
int64_t sum_projected_model = 0;
- int64_t sum_projected_ctx = 0;
if (nd > 1) {
LLAMA_LOG_INFO("%s: projected memory use with initial parameters [MiB]:\n", __func__);
@@ -197,12 +200,11 @@ static void llama_params_fit_impl(
const int64_t projected_used = dmd.mb.total();
const int64_t projected_free = dmd.free - projected_used;
- sum_total += dmd.total;
+ sum_free += dmd.free;
sum_projected_used += projected_used;
sum_projected_free += projected_free;
min_projected_free = std::min(min_projected_free, projected_free);
sum_projected_model += dmd.mb.model;
- sum_projected_ctx += dmd.mb.context;
if (nd > 1) {
LLAMA_LOG_INFO("%s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " %s\n",
@@ -210,10 +212,9 @@ static void llama_params_fit_impl(
projected_free >= 0 ? "surplus" : "deficit");
}
}
- assert(sum_total >= 0 && sum_projected_used >= 0 && sum_projected_ctx >= 0);
- assert(sum_projected_used >= sum_projected_ctx);
+ assert(sum_free >= 0 && sum_projected_used >= 0);
LLAMA_LOG_INFO("%s: projected to use %" PRId64 " MiB of device memory vs. %" PRId64 " MiB of free device memory\n",
- __func__, sum_projected_used/MiB, sum_total/MiB);
+ __func__, sum_projected_used/MiB, sum_free/MiB);
if (min_projected_free >= margin) {
if (nd == 1) {
LLAMA_LOG_INFO("%s: will leave %" PRId64 " >= %" PRId64 " MiB of free device memory, no changes needed\n",
@@ -236,9 +237,7 @@ static void llama_params_fit_impl(
__func__, margin/MiB, -global_surplus/MiB);
if (cparams->n_ctx == 0) {
if (hp_nct > n_ctx_min) {
- const int64_t bytes_per_ctx = sum_projected_ctx / hp_nct;
-
- int64_t memory_reduction = -global_surplus;
+ int64_t sum_used_target = sum_free - nd*margin_s;
if (nd > 1) {
// for multiple devices we need to be more conservative in terms of how much context we think can fit:
// - for dense models only whole layers can be assigned to devices
@@ -246,24 +245,34 @@ static void llama_params_fit_impl(
// - on average we expect a waste of 0.5 layers/tensors per device
// - use slightly more than the expected average for nd devices to be safe
const int64_t model_per_layer = sum_projected_model / std::min(uint32_t(mparams->n_gpu_layers), hp_ngl);
- memory_reduction += (nd + 1) * model_per_layer / (hp_nex == 0 ? 2 : 6);
+ sum_used_target -= (nd + 1) * model_per_layer / (hp_nex == 0 ? 2 : 6);
}
- uint32_t ctx_reduction = std::min(uint32_t((memory_reduction + bytes_per_ctx - 1) / bytes_per_ctx), hp_nct - n_ctx_min);
- cparams->n_ctx = hp_nct - ctx_reduction;
- cparams->n_ctx = std::max(cparams->n_ctx - cparams->n_ctx % 256, n_ctx_min); // round down context for CUDA backend
+ int64_t sum_projected_used_min_ctx = 0;
+ cparams->n_ctx = n_ctx_min;
+ const dmds_t dmds_min_ctx = llama_get_device_memory_data(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level);
+ for (const auto & dmd : dmds_min_ctx) {
+ sum_projected_used_min_ctx += dmd.mb.total();
+ }
+ if (sum_used_target > sum_projected_used_min_ctx) {
+ // linear interpolation between minimum and maximum context size:
+ cparams->n_ctx += (hp_nct - n_ctx_min) * (sum_used_target - sum_projected_used_min_ctx)
+ / (sum_projected_used - sum_projected_used_min_ctx);
+ cparams->n_ctx = std::max(cparams->n_ctx - cparams->n_ctx % 256, n_ctx_min); // round down context for CUDA backend
- ctx_reduction = hp_nct - cparams->n_ctx;
- memory_reduction = ctx_reduction * bytes_per_ctx;
- global_surplus += memory_reduction;
- LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n",
- __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB);
- if (global_surplus >= 0) {
+ const int64_t bytes_per_ctx = (sum_projected_used - sum_projected_used_min_ctx) / (hp_nct - n_ctx_min);
+ const int64_t memory_reduction = (hp_nct - cparams->n_ctx) * bytes_per_ctx;
+ LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n",
+ __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB);
if (nd == 1) {
LLAMA_LOG_INFO("%s: entire model can be fit by reducing context\n", __func__);
return;
}
LLAMA_LOG_INFO("%s: entire model should be fit across devices by reducing context\n", __func__);
+ } else {
+ const int64_t memory_reduction = sum_projected_used - sum_projected_used_min_ctx;
+ LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n",
+ __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB);
}
} else {
LLAMA_LOG_INFO("%s: default model context size is %" PRIu32 " which is <= the min. context size of %" PRIu32 " -> no change\n",
@@ -276,28 +285,28 @@ static void llama_params_fit_impl(
}
if (mparams->n_gpu_layers != default_mparams.n_gpu_layers) {
- throw std::runtime_error("n_gpu_layers already set by user to " + std::to_string(mparams->n_gpu_layers) + ", abort");
+ throw llama_params_fit_exception("n_gpu_layers already set by user to " + std::to_string(mparams->n_gpu_layers) + ", abort");
}
if (nd > 1) {
if (!tensor_split) {
- throw std::runtime_error("did not provide a buffer to write the tensor_split to, abort");
+ throw llama_params_fit_exception("did not provide a buffer to write the tensor_split to, abort");
}
if (mparams->tensor_split) {
for (size_t id = 0; id < nd; id++) {
if (mparams->tensor_split[id] != 0.0f) {
- throw std::runtime_error("model_params::tensor_split already set by user, abort");
+ throw llama_params_fit_exception("model_params::tensor_split already set by user, abort");
}
}
}
if (mparams->split_mode == LLAMA_SPLIT_MODE_ROW) {
- throw std::runtime_error("changing weight allocation for LLAMA_SPLIT_MODE_ROW not implemented, abort");
+ throw llama_params_fit_exception("changing weight allocation for LLAMA_SPLIT_MODE_ROW not implemented, abort");
}
}
if (!tensor_buft_overrides) {
- throw std::runtime_error("did not provide buffer to set tensor_buft_overrides, abort");
+ throw llama_params_fit_exception("did not provide buffer to set tensor_buft_overrides, abort");
}
if (mparams->tensor_buft_overrides && (mparams->tensor_buft_overrides->pattern || mparams->tensor_buft_overrides->buft)) {
- throw std::runtime_error("model_params::tensor_buft_overrides already set by user, abort");
+ throw llama_params_fit_exception("model_params::tensor_buft_overrides already set by user, abort");
}
// step 3: iteratively fill the back to front with "dense" layers
@@ -380,8 +389,8 @@ static void llama_params_fit_impl(
tensor_buft_overrides[itbo].buft = nullptr;
itbo++;
mparams.tensor_buft_overrides = tensor_buft_overrides;
- throw std::runtime_error("llama_params_fit_n_tensor_buft_overrides() == "
- + std::to_string(ntbo) + " is insufficient for model\n");
+ throw llama_params_fit_exception("llama_max_tensor_buft_overrides() == "
+ + std::to_string(ntbo) + " is insufficient for model");
}
tensor_buft_overrides[itbo].pattern = get_overflow_pattern(il, il == il0 ? ngl_per_device[id].overflow_type : LAYER_FRACTION_MOE);
tensor_buft_overrides[itbo].buft = overflow_bufts[id];
@@ -503,6 +512,9 @@ static void llama_params_fit_impl(
if (mem_high[id] > targets[id]) {
assert(ngl_per_device_high[id].n_layer > ngl_per_device[id].n_layer);
uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer;
+ if (hp_nex > 0 && size_t(id) == nd - 1) {
+ delta--;
+ }
LLAMA_LOG_DEBUG("%s: start filling device %" PRIu32 ", delta=%" PRIu32 "\n", __func__, id, delta);
while (delta > 1) {
uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]);
@@ -638,7 +650,7 @@ static void llama_params_fit_impl(
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_UP;
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP\n", __func__);
std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
- if (mem_test[id] < targets[id]) {
+ if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
ngl_per_device = ngl_per_device_test;
mem = mem_test;
id_dense_start = id_dense_start_test;
@@ -648,7 +660,7 @@ static void llama_params_fit_impl(
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_GATE;
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE\n", __func__);
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
- if (mem_test[id] < targets[id]) {
+ if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
ngl_per_device = ngl_per_device_test;
mem = mem_test;
id_dense_start = id_dense_start_test;
@@ -659,7 +671,7 @@ static void llama_params_fit_impl(
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_ATTN;
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_ATTN\n", __func__);
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
- if (mem_test[id] < targets[id]) {
+ if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
ngl_per_device = ngl_per_device_test;
mem = mem_test;
id_dense_start = id_dense_start_test;
@@ -678,22 +690,25 @@ static void llama_params_fit_impl(
set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams);
}
-bool llama_params_fit(
+enum llama_params_fit_status llama_params_fit(
const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams,
float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides,
size_t margin_s, uint32_t n_ctx_min, enum ggml_log_level log_level) {
const int64_t t0_us = llama_time_us();
- bool ok = true;
+ llama_params_fit_status status = LLAMA_PARAMS_FIT_STATUS_SUCCESS;
try {
llama_params_fit_impl(path_model, mparams, cparams, tensor_split, tensor_buft_overrides, margin_s, n_ctx_min, log_level);
LLAMA_LOG_INFO("%s: successfully fit params to free device memory\n", __func__);
- } catch (const std::runtime_error & e) {
+ } catch (const llama_params_fit_exception & e) {
LLAMA_LOG_WARN("%s: failed to fit params to free device memory: %s\n", __func__, e.what());
- ok = false;
+ status = LLAMA_PARAMS_FIT_STATUS_FAILURE;
+ } catch (const std::runtime_error & e) {
+ LLAMA_LOG_ERROR("%s: encountered an error while trying to fit params to free device memory: %s\n", __func__, e.what());
+ status = LLAMA_PARAMS_FIT_STATUS_ERROR;
}
const int64_t t1_us = llama_time_us();
LLAMA_LOG_INFO("%s: fitting params to free memory took %.2f seconds\n", __func__, (t1_us - t0_us) * 1e-6);
- return ok;
+ return status;
}
struct llama_sampler_chain_params llama_sampler_chain_default_params() {
diff --git a/src/models/llama.cpp b/src/models/llama.cpp
index ab7fd5d050..42b5fcdf42 100644
--- a/src/models/llama.cpp
+++ b/src/models/llama.cpp
@@ -1,6 +1,7 @@
#include "models.h"
-llm_build_llama::llm_build_llama(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+template
+llm_build_llama