Merge branch 'master' into imatrix

This commit is contained in:
Ed Addario 2025-12-28 10:09:11 +00:00
commit 1e6d93cc48
No known key found for this signature in database
GPG Key ID: E7875815A3230993
68 changed files with 2221 additions and 719 deletions

View File

@ -8,7 +8,8 @@ body:
value: >
Thanks for taking the time to fill out this bug report!
This issue template is intended for bug reports where the compilation of llama.cpp fails.
Before opening an issue, please confirm that the compilation still fails with `-DGGML_CCACHE=OFF`.
Before opening an issue, please confirm that the compilation still fails
after recreating the CMake build directory and with `-DGGML_CCACHE=OFF`.
If the compilation succeeds with ccache disabled you should be able to permanently fix the issue
by clearing `~/.cache/ccache` (on Linux).
- type: textarea

View File

@ -98,7 +98,18 @@ body:
label: Relevant log output
description: >
Please copy and paste any relevant log output, including the command that you entered and any generated text.
This will be automatically formatted into code, so no need for backticks.
render: shell
For very long logs (thousands of lines), preferably upload them as files instead.
On Linux you can redirect console output into a file by appending ` > llama.log 2>&1` to your command.
value: |
<details>
<summary>Logs</summary>
<!-- Copy-pasted short logs go into the "console" area here -->
```console
```
</details>
<!-- Long logs that you upload as files go here, outside the "console" area -->
validations:
required: true

View File

@ -85,8 +85,19 @@ body:
label: Relevant log output
description: >
If applicable, please copy and paste any relevant log output, including any generated text.
This will be automatically formatted into code, so no need for backticks.
If you are encountering problems specifically with the `llama_params_fit` module, always upload `--verbose` logs as well.
render: shell
For very long logs (thousands of lines), please upload them as files instead.
On Linux you can redirect console output into a file by appending ` > llama.log 2>&1` to your command.
value: |
<details>
<summary>Logs</summary>
<!-- Copy-pasted short logs go into the "console" area here -->
```console
```
</details>
<!-- Long logs that you upload as files go here, outside the "console" area -->
validations:
required: false

View File

@ -2087,7 +2087,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"override tensor buffer type", [](common_params & params, const std::string & value) {
parse_tensor_buffer_overrides(value, params.tensor_buft_overrides);
}
));
).set_env("LLAMA_ARG_OVERRIDE_TENSOR"));
add_opt(common_arg(
{"-otd", "--override-tensor-draft"}, "<tensor name pattern>=<buffer type>,...",
"override tensor buffer type for draft model", [](common_params & params, const std::string & value) {
@ -2137,11 +2137,18 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_N_CPU_MOE_DRAFT"));
GGML_ASSERT(params.n_gpu_layers < 0); // string_format would need to be extended for a default >= 0
add_opt(common_arg(
{"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
string_format("max. number of layers to store in VRAM (default: %d)", params.n_gpu_layers),
[](common_params & params, int value) {
params.n_gpu_layers = value;
string_format("max. number of layers to store in VRAM, either an exact number, 'auto', or 'all' (default: %s)", params.n_gpu_layers == -1 ? "auto" : "all"),
[](common_params & params, const std::string & value) {
if (value == "auto") {
params.n_gpu_layers = -1;
} else if (value == "all") {
params.n_gpu_layers = -2;
} else {
params.n_gpu_layers = std::stoi(value);
}
if (!llama_supports_gpu_offload()) {
fprintf(stderr, "warning: no usable GPU found, --gpu-layers option will be ignored\n");
fprintf(stderr, "warning: one possible reason is that llama.cpp was compiled without GPU support\n");
@ -3175,11 +3182,19 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.devices = parse_device_list(value);
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
GGML_ASSERT(params.speculative.n_gpu_layers < 0); // string_format would need to be extended for a default >= 0
add_opt(common_arg(
{"-ngld", "--gpu-layers-draft", "--n-gpu-layers-draft"}, "N",
"number of layers to store in VRAM for the draft model",
[](common_params & params, int value) {
params.speculative.n_gpu_layers = value;
string_format("max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: %s)",
params.speculative.n_gpu_layers == -1 ? "auto" : "all"),
[](common_params & params, const std::string & value) {
if (value == "auto") {
params.speculative.n_gpu_layers = -1;
} else if (value == "all") {
params.speculative.n_gpu_layers = -2;
} else {
params.speculative.n_gpu_layers = std::stoi(value);
}
if (!llama_supports_gpu_offload()) {
fprintf(stderr, "warning: no usable GPU found, --gpu-layers-draft option will be ignored\n");
fprintf(stderr, "warning: one possible reason is that llama.cpp was compiled without GPU support\n");
@ -3518,15 +3533,15 @@ void common_params_add_preset_options(std::vector<common_arg> & args) {
[](common_params &, const std::string &) { /* unused */ }
).set_env(COMMON_ARG_PRESET_LOAD_ON_STARTUP).set_preset_only());
args.push_back(common_arg(
{"stop-timeout"}, "SECONDS",
"in server router mode, force-kill model instance after this many seconds of graceful shutdown",
[](common_params &, int) { /* unused */ }
).set_env(COMMON_ARG_PRESET_STOP_TIMEOUT).set_preset_only());
// args.push_back(common_arg(
// {"pin"},
// "in server router mode, do not unload this model if models_max is exceeded",
// [](common_params &) { /* unused */ }
// ).set_preset_only());
// args.push_back(common_arg(
// {"unload-idle-seconds"}, "SECONDS",
// "in server router mode, unload models idle for more than this many seconds",
// [](common_params &, int) { /* unused */ }
// ).set_preset_only());
}

View File

@ -10,6 +10,7 @@
// pseudo-env variable to identify preset-only arguments
#define COMMON_ARG_PRESET_LOAD_ON_STARTUP "__PRESET_LOAD_ON_STARTUP"
#define COMMON_ARG_PRESET_STOP_TIMEOUT "__PRESET_STOP_TIMEOUT"
//
// CLI argument parsing

View File

@ -1341,10 +1341,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
mparams.devices = params.devices.data();
}
if (params.n_gpu_layers != -1) {
mparams.n_gpu_layers = params.n_gpu_layers;
}
mparams.n_gpu_layers = params.n_gpu_layers;
mparams.main_gpu = params.main_gpu;
mparams.split_mode = params.split_mode;
mparams.tensor_split = params.tensor_split;

View File

@ -329,7 +329,7 @@ struct common_params {
// offload params
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t n_gpu_layers = -1; // number of layers to store in VRAM, -1 is auto, <= -2 is all
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
bool fit_params = true; // whether to fit unset model/context parameters to free device memory

View File

@ -7362,6 +7362,90 @@ class MiniMaxM2Model(TextModel):
return super().modify_tensors(data_torch, name, bid)
@ModelBase.register("MiMoV2FlashForCausalLM")
class MimoV2Model(TextModel):
model_arch = gguf.MODEL_ARCH.MIMO2
def set_gguf_parameters(self):
super().set_gguf_parameters()
assert self.hparams["swa_head_dim"] == self.hparams["head_dim"]
assert self.hparams["swa_num_attention_heads"] == self.hparams["num_attention_heads"]
assert self.hparams["swa_v_head_dim"] == self.hparams["v_head_dim"]
assert self.hparams["topk_method"] == "noaux_tc"
n_head_kv = self.hparams["num_key_value_heads"]
n_head_kv_swa = self.hparams["swa_num_key_value_heads"]
n_head_kv_arr = [n_head_kv_swa if use_swa == 1 else n_head_kv for use_swa in self.hparams["hybrid_layer_pattern"]]
self.gguf_writer.add_head_count_kv(n_head_kv_arr)
self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
self.gguf_writer.add_sliding_window_pattern(self.hparams["hybrid_layer_pattern"])
self.gguf_writer.add_rope_freq_base_swa(self.hparams["swa_rope_theta"])
self.gguf_writer.add_value_length(self.hparams["v_head_dim"])
self.gguf_writer.add_expert_count(self.hparams["n_routed_experts"])
self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
rope_dim = int(self.hparams["head_dim"] * self.hparams["partial_rotary_factor"])
self.gguf_writer.add_rope_dimension_count(rope_dim)
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon", 1e-5))
_experts: list[dict[str, Tensor]] | None = None
def modify_tensors(self, data_torch, name, bid):
if name.endswith("e_score_correction_bias"):
name = name.replace("e_score_correction_bias", "e_score_correction.bias")
if "attention_sink" in name and not name.endswith(".weight"):
name += ".weight"
# TODO: mimo v2 does not indicate the number of next-token-prediction layers, therefore we cannot do the same way as GLM4_MOE
if "model.mtp." in name:
return []
# process the experts separately
if name.find("mlp.experts") != -1:
n_experts = self.hparams["n_routed_experts"]
assert bid is not None
if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]
self._experts[bid][name] = data_torch
if len(self._experts[bid]) >= n_experts * 3:
tensors: list[tuple[str, Tensor]] = []
# merge the experts into a single 3d tensor
for w_name in ["gate_proj", "up_proj", "down_proj"]:
datas: list[Tensor] = []
for xid in range(n_experts):
ename_to_retrieve = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename_to_retrieve])
del self._experts[bid][ename_to_retrieve]
data_torch = torch.stack(datas, dim=0)
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
new_name = self.map_tensor_name(merged_name)
tensors.append((new_name, data_torch))
return tensors
else:
return []
return [(self.map_tensor_name(name), data_torch)]
def prepare_tensors(self):
super().prepare_tensors()
if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")
@ModelBase.register("PanguEmbeddedForCausalLM")
class PanguEmbeddedModel(TextModel):
model_arch = gguf.MODEL_ARCH.PANGU_EMBED
@ -8695,6 +8779,11 @@ class NemotronHModel(GraniteHybridModel):
raise ValueError(f"Unprocessed experts: {experts}")
@ModelBase.register("LlamaBidirectionalModel")
class LlamaEmbedNemotronModel(LlamaModel):
model_arch = gguf.MODEL_ARCH.LLAMA_EMBED
@ModelBase.register("BailingMoeForCausalLM")
class BailingMoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.BAILINGMOE

View File

@ -17,7 +17,7 @@ OpenCL (Open Computing Language) is an open, royalty-free standard for cross-pla
### Llama.cpp + OpenCL
The llama.cpp OpenCL backend is designed to enable llama.cpp on **Qualcomm Adreno GPU** firstly via OpenCL. Thanks to the portabilty of OpenCL, the OpenCL backend can also run on certain Intel GPUs although the performance is not optimal.
The llama.cpp OpenCL backend is designed to enable llama.cpp on **Qualcomm Adreno GPU** firstly via OpenCL. Thanks to the portabilty of OpenCL, the OpenCL backend can also run on certain Intel GPUs such as those that do not have [SYCL](/docs/backend/SYCL.md) support although the performance is not optimal.
## OS

View File

@ -829,7 +829,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
No. We can't support Ollama issue directly, because we aren't familiar with Ollama.
Sugguest reproducing on llama.cpp and report similar issue to llama.cpp. We will surpport it.
Suggest reproducing on llama.cpp and report similar issue to llama.cpp. We will support it.
It's same for other projects including llama.cpp SYCL backend.

View File

@ -430,10 +430,22 @@ if (MSVC)
configure_msvc_target(ggml-cpu-x64)
configure_msvc_target(ggml-cpu-sse42)
configure_msvc_target(ggml-cpu-sandybridge)
# __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
# skipping ggml-cpu-ivybridge
# skipping ggml-cpu-piledriver
configure_msvc_target(ggml-cpu-haswell)
configure_msvc_target(ggml-cpu-skylakex)
configure_msvc_target(ggml-cpu-cannonlake)
configure_msvc_target(ggml-cpu-cascadelake)
configure_msvc_target(ggml-cpu-icelake)
# MSVC 2022 doesn't support BF16 intrinsics without `/arch:AVX10.1` ?!
# https://learn.microsoft.com/en-us/cpp/intrinsics/x64-amd64-intrinsics-list?view=msvc-170
# https://learn.microsoft.com/en-us/cpp/build/reference/arch-x64?view=msvc-170
# skipping ggml-cpu-cooperlake
# skipping ggml-cpu-zen4
configure_msvc_target(ggml-cpu-alderlake)
# MSVC doesn't support AMX
# skipping ggml-cpu-sapphirerapids
if (GGML_BUILD_EXAMPLES)
configure_msvc_target(common-ggml)

View File

@ -357,15 +357,29 @@ if (GGML_CPU_ALL_VARIANTS)
endif()
if (GGML_SYSTEM_ARCH STREQUAL "x86")
ggml_add_cpu_backend_variant(x64)
ggml_add_cpu_backend_variant(sse42 SSE42)
ggml_add_cpu_backend_variant(sandybridge SSE42 AVX)
ggml_add_cpu_backend_variant(haswell SSE42 AVX F16C AVX2 BMI2 FMA)
ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C AVX2 BMI2 FMA AVX512)
ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C AVX2 BMI2 FMA AVX_VNNI)
ggml_add_cpu_backend_variant(sse42 SSE42)
ggml_add_cpu_backend_variant(sandybridge SSE42 AVX)
if (NOT MSVC)
# __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
ggml_add_cpu_backend_variant(ivybridge SSE42 AVX F16C)
ggml_add_cpu_backend_variant(piledriver SSE42 AVX F16C FMA)
endif()
ggml_add_cpu_backend_variant(haswell SSE42 AVX F16C FMA AVX2 BMI2)
ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C FMA AVX2 BMI2 AVX512)
ggml_add_cpu_backend_variant(cannonlake SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI)
ggml_add_cpu_backend_variant(cascadelake SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VNNI)
ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI AVX512_VNNI)
if (NOT MSVC)
# MSVC 2022 doesn't support BF16 intrinsics without `/arch:AVX10.1` ?!
# https://learn.microsoft.com/en-us/cpp/intrinsics/x64-amd64-intrinsics-list?view=msvc-170
# https://learn.microsoft.com/en-us/cpp/build/reference/arch-x64?view=msvc-170
ggml_add_cpu_backend_variant(cooperlake SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VNNI AVX512_BF16)
ggml_add_cpu_backend_variant(zen4 SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16)
endif()
ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C FMA AVX2 BMI2 AVX_VNNI)
if (NOT MSVC)
# MSVC doesn't support AMX
ggml_add_cpu_backend_variant(sapphirerapids SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
ggml_add_cpu_backend_variant(sapphirerapids SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
endif()
elseif(GGML_SYSTEM_ARCH STREQUAL "ARM")
if (CMAKE_SYSTEM_NAME MATCHES "Linux")

View File

@ -2338,19 +2338,19 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
// Step1.2: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor.
// TODO: acl_yarn_ramp_tensor use rope cache.
bool yarn_ramp_tensor_updated = false;
ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
acl_tensor_ptr acl_yarn_ramp_tensor;
if (ext_factor != 0 && (theta_scale_updated || ctx.rope_cache.theta_scale_length != theta_scale_length ||
ctx.rope_cache.freq_scale != freq_scale)) {
yarn_ramp_tensor_updated = true;
if (ctx.rope_cache.yarn_ramp_cache != nullptr) {
ACL_CHECK(aclrtFree(ctx.rope_cache.yarn_ramp_cache));
}
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.yarn_ramp_cache, theta_scale_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
// -rope_yarn_ramp
// const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
// return MIN(1, MAX(0, y)) - 1;
yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
void * yarn_ramp_buffer = yarn_ramp_allocator.get();
acl_yarn_ramp_tensor =
ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
float zero_value = 0, one_value = 1;
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT);
@ -2380,8 +2380,10 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
acl_scalar_ptr freq_scale_1_sc = ggml_cann_create_scalar(&freq_scale_1, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get());
} else {
acl_yarn_ramp_tensor =
ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
}
// Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale.
if (ext_factor != 0) {
if (theta_scale_updated || yarn_ramp_tensor_updated) {
@ -2988,32 +2990,156 @@ void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
GGML_CANN_CALL_ACLNN_OP(ctx, ArgMax, acl_src.get(), 3, false, acl_dst.get());
}
void ggml_cann_conv_transpose_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){
ggml_tensor * src0 = dst->src[0];
ggml_tensor * src1 = dst->src[1];
// stride
int64_t s0 = ((const int32_t *) (dst->op_params))[0];
int64_t s0 = ((const int32_t*)(dst->op_params))[0];
acl_tensor_ptr acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL);
acl_tensor_ptr acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL);
acl_tensor_ptr acl_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3, ACL_FORMAT_NCL);
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL);
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL);
// get base information of input and kernel
int64_t input_len = *(src1->ne);
int64_t dst_len = *(dst->ne);
int64_t kernel_size = *(src0->ne);
// set the max kernel size for each conv
int64_t max_kernel_size = 255;
// compute the partition of kernel
int64_t part_num = 1;
part_num = (kernel_size + max_kernel_size - 1) / max_kernel_size;
int64_t strideVal[1];
strideVal[0] = s0;
acl_int_array_ptr stride = ggml_cann_create_int_array(strideVal, 1);
int64_t paddingVal[] = { 0 };
acl_int_array_ptr padding = ggml_cann_create_int_array(paddingVal, 1);
int64_t dilationVal[] = { 1 };
acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1);
int8_t cubeMathType = 0;
strideVal[0] = s0;
acl_int_array_ptr stride = ggml_cann_create_int_array(strideVal, 1);
int64_t paddingVal[] = {0};
acl_int_array_ptr padding = ggml_cann_create_int_array(paddingVal, 1);
int64_t dilationVal[] = {1};
acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1);
bool transposed = true;
int64_t groups = 1;
int8_t cubeMathType = 0;
#ifdef ASCEND_310P
cubeMathType = 1;
#endif
GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input.get(), acl_weight.get(), nullptr, stride.get(), padding.get(),
dilation.get(), true, padding.get(), 1, acl_dst.get(), cubeMathType);
auto weight_type = ggml_cann_type_mapping(src0->type);
auto dst_type = ggml_cann_type_mapping(dst->type);
// slice the kernel to make each conv available
int64_t slice_dim = -1;
int64_t slice_start = 0;
int64_t slice_end = max_kernel_size;
int64_t slice_step = 1;
int64_t interval = max_kernel_size;
int64_t left_pad_len = dilationVal[0] * (max_kernel_size - 1) + 1 - 2 * paddingVal[0];
int64_t right_pad_len = 0;
acl_scalar_ptr alpha = nullptr;
float alphaValue = 1.0;
alpha = ggml_cann_create_scalar(&alphaValue, aclDataType::ACL_FLOAT);
// set zero to destination
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get());
for(int k = 0; k < part_num; k++){
// create part kernel tensor and slice from big kernel
slice_start = max_kernel_size * k;
if(k == part_num - 1){
slice_end = kernel_size;
interval = kernel_size - max_kernel_size * k;
}else{
slice_end = max_kernel_size * (k+1);
}
int64_t part_ne[4];
for(int i = 0; i < 4; i++) {
part_ne[i] = *(src0->ne + i);
}
part_ne[0] = interval;
size_t part_nb[4];
part_nb[0] = sizeof(weight_type);
for (int i = 1; i < 4; i++) {
part_nb[i] = part_nb[i - 1] * part_ne[i - 1];
}
ggml_cann_pool_alloc part_kernel_allocator;
part_kernel_allocator.alloc(ctx.pool(), part_nb[3]);
void* part_kernel_buf = part_kernel_allocator.get();
acl_tensor_ptr part_kernel = ggml_cann_create_tensor(part_kernel_buf, weight_type,
ggml_element_size(src0), part_ne, part_nb, 3, ACL_FORMAT_NCL);
GGML_CANN_CALL_ACLNN_OP(ctx, Slice, acl_weight.get(), slice_dim, slice_start, slice_end, slice_step, part_kernel.get());
// create the part conv result tensor
int64_t part_dst_ne[4];
for(int i = 0; i < 4; i++){
part_dst_ne[i] = *(dst->ne + i);
}
part_dst_ne[0] = (input_len - 1) * strideVal[0] - 2 * paddingVal[0] + dilationVal[0] * (part_ne[0] - 1) + 1;
size_t part_dst_nb[4];
part_dst_nb[0] = sizeof(weight_type);
for (int i = 1; i < 4; i++) {
part_dst_nb[i] = part_dst_nb[i - 1] * part_dst_ne[i - 1];
}
ggml_cann_pool_alloc part_dst_allocator;
part_dst_allocator.alloc(ctx.pool(), part_dst_nb[3]);
void* part_dst_buf = part_dst_allocator.get();
acl_tensor_ptr acl_part_dst = ggml_cann_create_tensor(part_dst_buf, dst_type, ggml_element_size(dst),
part_dst_ne, part_dst_nb, 3, ACL_FORMAT_NCL);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_part_dst.get());
// compute part conv transpose 1d
GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input.get(), part_kernel.get(), nullptr, stride.get(),
padding.get(), dilation.get(), transposed, padding.get(), groups, acl_part_dst.get(), cubeMathType);
// compute the position of part result in final result
int64_t global_start = slice_start;
int64_t global_end = std::min((input_len - 1) * strideVal[0] + slice_end, dst_len);
left_pad_len = global_start;
right_pad_len = dst_len - global_end;
std::vector<int64_t> padDataVal = {left_pad_len,right_pad_len};
acl_int_array_ptr padData = ggml_cann_create_int_array(padDataVal.data(), 2);
acl_scalar_ptr pad_value = nullptr;
float pad_valueVal = 0.0;
pad_value = ggml_cann_create_scalar(&pad_valueVal, aclDataType::ACL_FLOAT);
int64_t conv_result_ne[4];
for(int i = 0; i < 4; i++){
conv_result_ne[i] = *(dst->ne + i);
}
size_t conv_result_nb[4];
conv_result_nb[0] = sizeof(weight_type);
for (int i = 1; i < 4; i++) {
conv_result_nb[i] = conv_result_nb[i - 1] * conv_result_ne[i - 1];
}
ggml_cann_pool_alloc conv_result_allocator;
conv_result_allocator.alloc(ctx.pool(), conv_result_nb[3]);
void* conv_result_buf = conv_result_allocator.get();
acl_tensor_ptr conv_result = ggml_cann_create_tensor(conv_result_buf, dst_type, ggml_element_size(dst),
conv_result_ne, conv_result_nb, 3, ACL_FORMAT_NCL);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, conv_result.get());
GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_part_dst.get(), padData.get(), pad_value.get(), conv_result.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst.get(), conv_result.get(), alpha.get());
}
}
void ggml_cann_elu(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
@ -3576,3 +3702,106 @@ void ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
break;
}
}
void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0]; // conv_x
ggml_tensor * src1 = dst->src[1]; // conv1d.weight
// This op is currently defined only for F32 in ggml_cpu
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
// Shapes follow ggml_compute_forward_ssm_conv_f32
const int64_t nc = src1->ne[0]; // d_conv
const int64_t ncs = src0->ne[0]; // d_conv - 1 + n_t
const int64_t nr = src0->ne[1]; // d_inner
const int64_t n_s = src0->ne[2]; // n_seqs
const int64_t n_t = dst->ne[1]; // tokens per sequence
GGML_ASSERT(dst->ne[0] == nr); // dst: {d_inner, n_t, n_s}
GGML_ASSERT(src1->ne[1] == nr); // weight: {d_conv, d_inner}
GGML_ASSERT(ncs == nc - 1 + n_t); // conv_x: {d_conv - 1 + n_t, d_inner, n_s}
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
// --- Build CANN tensors ---
// 1) Input: conv_x as NCL
//
// src0->ne = { ncs, nr, n_s, 1 } // {L_in, C, N}
// Passing ACL_FORMAT_NCL here means:
// reversed dims -> [N, C, L_in] = [n_s, nr, ncs]
acl_tensor_ptr acl_x = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3, ACL_FORMAT_NCL);
// 2) Weights: depthwise conv kernel, view src1 as {K, 1, C}
//
// src1 original: ne = { nc, nr, 1, 1 } // [K, C, 1, 1]
// we want a view: ne_w = { nc, 1, nr } // [K, 1, C]
// so that reversed dims -> [C, 1, K] which matches
// [out_channels, in_channels/groups, kernel_size]
int64_t w_ne[GGML_MAX_DIMS] = { nc, 1, nr, 1 }; // [K, 1 input ch. per group, C groups]
// Layout: src1 data is [K, C] with
// offset(k, c) = k*nb0 + c*nb1
// We want offset_w(k, 0, c) = k*nb0 + c*nb1,
// so we can reuse nb0 and nb1, and set nb2 = nb1.
size_t w_nb[GGML_MAX_DIMS] = { src1->nb[0], src1->nb[1], src1->nb[1], src1->nb[3] }; // same as src1
acl_tensor_ptr acl_w = ggml_cann_create_tensor(
src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), w_ne, w_nb, 3, ACL_FORMAT_NCL);
// 3) Output: dst is { d_inner, n_t, n_s } (CLN)
//
// We need an NCL view of the same buffer:
// desired NCL logical shape: { L_out = n_t, C = nr, N = n_s }
//
// Original CLN layout:
// dst->ne = { nr, n_t, n_s }
// dst->nb[0] = sizeof(float)
// dst->nb[1] = nr * sizeof(float)
// dst->nb[2] = nr * n_t * sizeof(float)
//
// We want offset_new(L, C, N) = offset_orig(C, L, N).
// Choose:
// nb_y[0] = nr * sizeof(float); // step in L
// nb_y[1] = sizeof(float); // step in C
// nb_y[2] = nr * n_t * sizeof(float); // step in N
int64_t y_ne[GGML_MAX_DIMS] = { n_t, nr, n_s, 1 }; // [L_out, C, N]
size_t y_nb[GGML_MAX_DIMS] = { dst->ne[0] * sizeof(float), sizeof(float), dst->ne[0] * dst->ne[1] * sizeof(float), dst->nb[3] }; // [nr, 1, nr * n_t]
acl_tensor_ptr acl_y = ggml_cann_create_tensor(
dst->data, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), y_ne, y_nb, 3, ACL_FORMAT_NCL);
// --- Conv1d parameters: depthwise, stride 1, no padding ("valid") ---
int64_t strideVal[1] = { 1 };
int64_t paddingVal[1] = { 0 };
int64_t dilationVal[1] = { 1 };
acl_int_array_ptr stride = ggml_cann_create_int_array(strideVal, 1);
acl_int_array_ptr padding = ggml_cann_create_int_array(paddingVal, 1);
acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1);
const bool transposed = false;
const int64_t groups = nr; // depthwise: one group per inner dim
int8_t cubeMathType = 0;
#ifdef ASCEND_310P
cubeMathType = 1;
#endif
GGML_CANN_CALL_ACLNN_OP(ctx,
Convolution,
acl_x.get(), // input: N, C, L_in = ncs
acl_w.get(), // weight: [C, 1, K] with groups=nr
nullptr, // bias
stride.get(),
padding.get(),
dilation.get(),
transposed,
padding.get(), // output padding (unused for non-transposed)
groups,
acl_y.get(),
cubeMathType);
}

View File

@ -47,6 +47,7 @@
#include <aclnnop/aclnn_sign.h>
#include <aclnnop/aclnn_silu.h>
#include <aclnnop/aclnn_sin.h>
#include <aclnnop/aclnn_slice.h>
#include <aclnnop/aclnn_sqrt.h>
#include <aclnnop/aclnn_tanh.h>
@ -1032,6 +1033,8 @@ void ggml_cann_op_unary(std::function<void(ggml_backend_cann_context &, aclTenso
ggml_backend_cann_context & ctx,
ggml_tensor * dst);
void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Applies a gated (GLU-style) unary operation using the CANN backend.
*

View File

@ -229,6 +229,60 @@ struct ggml_graph_node_properties {
// op
ggml_op node_op;
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
/**
* @brief Check if a ggml tensor node matches this property set.
*
* This function compares all relevant fields (address, op type, shape, source inputs, op params)
* to determine whether the current node matches these previously recorded properties.
*
* @param node The current ggml tensor node.
* @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
*/
bool has_matching_properties(ggml_tensor * node) {
if (node->data != this->node_address && node->op != GGML_OP_VIEW) {
return false;
}
if (node->op != this->node_op) {
return false;
}
for (int i = 0; i < GGML_MAX_DIMS; i++) {
if (node->ne[i] != this->ne[i]) {
return false;
}
if (node->nb[i] != this->nb[i]) {
return false;
}
}
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (node->src[i]) {
if (node->src[i]->data != this->src_address[i] && node->op != GGML_OP_VIEW) {
return false;
}
for (int d = 0; d < GGML_MAX_DIMS; d++) {
if (node->src[i]->ne[d] != this->src_ne[i][d]) {
return false;
}
if (node->src[i]->nb[d] != this->src_nb[i][d]) {
return false;
}
}
} else {
if (this->src_address[i] != nullptr) {
return false;
}
}
}
if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) {
return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
}
return true;
}
};
struct ggml_cann_graph {
@ -241,6 +295,79 @@ struct ggml_cann_graph {
aclmdlRI graph = nullptr;
std::vector<ggml_graph_node_properties> ggml_graph_properties;
/**
* @brief Create a new CANN graph from a ggml computation graph.
*
* This function creates a new ggml_cann_graph object and fills its node properties
* (operation type, dimensions, strides, input sources, and operation parameters)
* based on the current ggml computation graph.
*
* Each node in the ggml graph is mapped to a property entry in the new CANN graph:
* - node address
* - operation type
* - shape (ne) and strides (nb)
* - source tensor addresses
* - operation parameters
*
* @param cgraph The current ggml computation graph.
* @return Pointer to the newly created ggml_cann_graph object.
*/
static ggml_cann_graph * create_from_cgraph(ggml_cgraph * cgraph) {
ggml_cann_graph * new_graph = new ggml_cann_graph();
new_graph->ggml_graph_properties.resize(cgraph->n_nodes);
for (int node_idx = 0; node_idx < cgraph->n_nodes; ++node_idx) {
ggml_tensor * node = cgraph->nodes[node_idx];
auto & prop = new_graph->ggml_graph_properties[node_idx];
prop.node_address = node->data;
prop.node_op = node->op;
std::copy_n(node->ne, GGML_MAX_DIMS, prop.ne);
std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);
for (int src = 0; src < GGML_MAX_SRC; ++src) {
if (node->src[src]) {
prop.src_address[src] = node->src[src]->data;
std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]);
std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]);
} else {
prop.src_address[src] = nullptr;
std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0);
std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0);
}
}
memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);
}
return new_graph;
}
/**
* @brief Check whether this CANN graph matches the given ggml computation graph.
*
* This function compares the number of nodes and each node's properties
* (operation type, dimensions, strides, inputs, and operation parameters)
* to determine whether this CANN graph matches the given ggml graph.
*
* @param cgraph The current ggml computation graph.
* @return true if this CANN graph matches the ggml graph; false otherwise.
*/
bool matches_cgraph(ggml_cgraph * cgraph) {
if (this->ggml_graph_properties.size() != static_cast<size_t>(cgraph->n_nodes)) {
return false;
}
for (int i = 0; i < cgraph->n_nodes; ++i) {
if (!this->ggml_graph_properties[i].has_matching_properties(cgraph->nodes[i])) {
return false;
}
}
return true;
}
};
/**
@ -272,15 +399,6 @@ struct ggml_cann_graph_lru_cache {
cache_list.push_front(new_node);
}
/**
* @brief Move an existing graph to the front of the cache.
* @param node Pointer to the ggml_cann_graph to move.
*/
void move_to_front(ggml_cann_graph * node) {
cache_list.remove(node);
cache_list.push_front(node);
}
/**
* @brief Clear all graphs from the cache (also frees memory).
*/
@ -295,6 +413,28 @@ struct ggml_cann_graph_lru_cache {
* @brief Destructor that clears the cache and frees all cached graphs.
*/
~ggml_cann_graph_lru_cache() { clear(); }
/**
* @brief Find a cached CANN graph that matches the given ggml graph and move it to front.
*
* This function iterates through the cached CANN graphs stored in the LRU cache and
* compares them against the given ggml computation graph. If a matching graph is found,
* it is promoted to the front of the LRU cache and returned. Otherwise, the function
* returns nullptr.
*
* @param cgraph The current ggml computation graph.
* @return true if found; false otherwise.
*/
bool find_and_move_to_front(ggml_cgraph * cgraph) {
for (auto & graph_ptr : this->cache_list) {
if (graph_ptr->matches_cgraph(cgraph)) {
cache_list.remove(graph_ptr);
cache_list.push_front(graph_ptr);
return true;
}
}
return false;
}
};
#endif // USE_ACL_GRAPH
@ -318,6 +458,9 @@ struct ggml_cann_rope_cache {
if (position_select_index_host) {
free(position_select_index_host);
}
if (yarn_ramp_cache) {
ACL_CHECK(aclrtFree(yarn_ramp_cache));
}
}
bool equal(int64_t theta_scale_length,
@ -370,6 +513,7 @@ struct ggml_cann_rope_cache {
float * theta_scale_exp_host = nullptr;
int * position_select_index_host = nullptr;
void * position_select_index = nullptr;
void * yarn_ramp_cache = nullptr;
// sin/cos cache, used only to accelerate first layer on each device
void * sin_cache = nullptr;
void * cos_cache = nullptr;

View File

@ -1888,6 +1888,8 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
break;
case GGML_OP_OUT_PROD:
ggml_cann_out_prod(ctx, dst);
case GGML_OP_SSM_CONV:
ggml_cann_ssm_conv(ctx, dst);
break;
default:
return false;
@ -2075,162 +2077,6 @@ static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
}
#ifdef USE_ACL_GRAPH
/**
* @brief Add a new CANN graph to the LRU cache by populating node properties from the ggml graph.
*
* This function creates a new ggml_cann_graph object and fills its node properties
* (operation type, dimensions, strides, input sources, and operation parameters)
* based on the current ggml computation graph.
*
* Each node in the ggml graph is mapped to a property entry in the new CANN graph:
* - node address
* - operation type
* - shape (ne) and strides (nb)
* - source tensor addresses
* - operation parameters
*
* After initialization, the new graph is pushed into the LRU cache owned by the
* CANN backend context. The cache takes ownership of the graph and manages its
* lifetime (including deletion upon eviction).
*
* @param cann_ctx The CANN backend context containing the graph cache.
* @param cgraph The current ggml computation graph.
*/
static void add_lru_matched_graph_node_properties(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
// Create a new ggml_cann_graph object on the heap (its lifetime is managed by the cache).
ggml_cann_graph * new_graph = new ggml_cann_graph();
new_graph->ggml_graph_properties.resize(cgraph->n_nodes);
for (int node_idx = 0; node_idx < cgraph->n_nodes; ++node_idx) {
ggml_tensor * node = cgraph->nodes[node_idx];
auto & prop = new_graph->ggml_graph_properties[node_idx];
prop.node_address = node->data;
prop.node_op = node->op;
std::copy_n(node->ne, GGML_MAX_DIMS, prop.ne);
std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);
for (int src = 0; src < GGML_MAX_SRC; ++src) {
if (node->src[src]) {
prop.src_address[src] = node->src[src]->data;
std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]);
std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]);
} else {
prop.src_address[src] = nullptr;
std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0);
std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0);
}
}
memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);
}
// Insert into the LRU cache (cache takes ownership and will delete it when evicted).
cann_ctx->graph_lru_cache.push(new_graph);
}
/**
* @brief Check if a ggml tensor node matches a previously captured CANN graph node.
*
* This function compares all relevant fields (address, op type, shape, source inputs, op params)
* to determine whether the current node matches a previously recorded version.
*
* @param node The current ggml tensor node.
* @param graph_node_properties The stored properties of a CANN graph node.
* @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
*/
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node,
ggml_graph_node_properties * graph_node_properties) {
if (node->data != graph_node_properties->node_address && node->op != GGML_OP_VIEW) {
return false;
}
if (node->op != graph_node_properties->node_op) {
return false;
}
for (int i = 0; i < GGML_MAX_DIMS; i++) {
if (node->ne[i] != graph_node_properties->ne[i]) {
return false;
}
if (node->nb[i] != graph_node_properties->nb[i]) {
return false;
}
}
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (node->src[i]) {
if (node->src[i]->data != graph_node_properties->src_address[i] && node->op != GGML_OP_VIEW) {
return false;
}
for (int d = 0; d < GGML_MAX_DIMS; d++) {
if (node->src[i]->ne[d] != graph_node_properties->src_ne[i][d]) {
return false;
}
if (node->src[i]->nb[d] != graph_node_properties->src_nb[i][d]) {
return false;
}
}
} else {
if (graph_node_properties->src_address[i] != nullptr) {
return false;
}
}
}
if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) {
return memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
}
return true;
}
/**
* @brief Check whether there is a cached CANN graph that matches the current ggml graph.
*
* This function iterates through the cached CANN graphs stored in the LRU cache and
* compares them against the given ggml computation graph. A match requires that the
* number of nodes is the same and that each nodes properties (operation type,
* dimensions, strides, inputs, and operation parameters) are identical.
*
* If a matching graph is found, it is promoted to the front of the LRU cache and the
* function returns true. Otherwise, the function returns false, indicating that a new
* CANN graph needs to be captured.
*
* @param cann_ctx The CANN backend context containing the graph cache.
* @param cgraph The current ggml computation graph.
* @return true if a matching cached graph exists; false otherwise.
*/
static bool is_matched_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
ggml_cann_graph_lru_cache & lru_cache = cann_ctx->graph_lru_cache;
for (auto & graph_ptr : lru_cache.cache_list) {
// Skip graphs with a different number of nodes.
if (graph_ptr->ggml_graph_properties.size() != static_cast<size_t>(cgraph->n_nodes)) {
continue;
}
// Check if all nodes match.
bool all_match = true;
for (int i = 0; i < cgraph->n_nodes; ++i) {
if (!ggml_graph_node_has_matching_properties(cgraph->nodes[i], &graph_ptr->ggml_graph_properties[i])) {
all_match = false;
break;
}
}
if (all_match) {
// update cache_list && renturn graph_ptr
lru_cache.move_to_front(graph_ptr);
return true;
}
}
return false;
}
#endif // USE_ACL_GRAPH
/**
* @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API.
*
@ -2239,23 +2085,23 @@ static bool is_matched_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph *
*
* Otherwise, it falls back to op-by-op execution using the CANN compute kernel dispatcher.
*
* @param cann_ctx The CANN backend context.
* @param cgraph The ggml computation graph.
* @param use_cann_graph Whether to use CANN graph execution.
* @param cann_graph_update_required Whether graph capture is needed due to graph changes.
* @param cann_ctx The CANN backend context.
* @param cgraph The ggml computation graph.
* @param use_cann_graph Whether to use CANN graph execution.
* @param cann_graph_capture_required Whether graph capture is needed due to graph changes.
*/
static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx,
ggml_cgraph * cgraph,
bool & use_cann_graph,
bool & cann_graph_update_required) {
bool use_cann_graph,
bool cann_graph_capture_required) {
#ifdef USE_ACL_GRAPH
if (use_cann_graph && cann_graph_update_required) { // Begin CANN graph capture
if (use_cann_graph && cann_graph_capture_required) { // Begin CANN graph capture
ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL));
}
#endif // USE_ACL_GRAPH
// Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
// With the use of CANN graphs, the execution will be performed by the graph launch.
if (!use_cann_graph || cann_graph_update_required) {
if (!use_cann_graph || cann_graph_capture_required) {
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
@ -2274,9 +2120,10 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
#ifdef USE_ACL_GRAPH
if (use_cann_graph) {
GGML_ASSERT(!cann_ctx->graph_lru_cache.cache_list.empty());
ggml_cann_graph * matched_graph = cann_ctx->graph_lru_cache.cache_list.front();
if (cann_graph_update_required) { // End CANN graph capture
if (cann_graph_capture_required) { // End CANN graph capture
ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph));
}
@ -2306,7 +2153,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
// calculate rope cache for fist layer in current device.
cann_ctx->rope_cache.cached = false;
bool cann_graph_update_required = false;
bool graph_capture_required = false;
#ifdef USE_ACL_GRAPH
bool use_cann_graph = true;
@ -2331,16 +2178,17 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
if (use_cann_graph) {
// If no matching graph is found, the graph needs to be recaptured.
cann_graph_update_required = !is_matched_graph(cann_ctx, cgraph);
if (cann_graph_update_required) {
graph_capture_required = !cann_ctx->graph_lru_cache.find_and_move_to_front(cgraph);
if (graph_capture_required) {
// If no matching graph is found, add a new ACL graph.
add_lru_matched_graph_node_properties(cann_ctx, cgraph);
ggml_cann_graph * new_graph = ggml_cann_graph::create_from_cgraph(cgraph);
cann_ctx->graph_lru_cache.push(new_graph);
}
}
#else
bool use_cann_graph = false;
#endif // USE_ACL_GRAPH
evaluate_and_capture_cann_graph(cann_ctx, cgraph, use_cann_graph, cann_graph_update_required);
evaluate_and_capture_cann_graph(cann_ctx, cgraph, use_cann_graph, graph_capture_required);
return GGML_STATUS_SUCCESS;
}
@ -2578,8 +2426,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
}
}
case GGML_OP_CONV_TRANSPOSE_1D:
// TODO: ((weightL - 1) * dilationW - padLeft)=1336 should not be larger than 255.
return (op->src[0]->ne[0] - 1) <= 255;
return true;
case GGML_OP_SCALE:
float bias;
memcpy(&bias, (const float *) (op->op_params) + 1, sizeof(float));
@ -2626,6 +2473,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
}
return true;
}
case GGML_OP_SSM_CONV:
return true;
default:
return false;
}

View File

@ -328,7 +328,7 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b)
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <intrin.h>
#elif defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__)
#elif defined(__SSE__) || defined(__SSE3__) || defined(__SSSE3__) || defined(__AVX__) || defined(__F16C__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX512BF16__)
#include <immintrin.h>
#endif

View File

@ -14,10 +14,6 @@
#include <arm_neon.h>
#endif
#if defined(__F16C__)
#include <immintrin.h>
#endif
#if defined(__riscv_v_intrinsic)
#include <riscv_vector.h>
#endif

View File

@ -15,6 +15,7 @@ if (CUDAToolkit_FOUND)
# 80 == Ampere, asynchronous data loading, faster tensor core instructions
# 86 == RTX 3000, needs CUDA v11.1
# 89 == RTX 4000, needs CUDA v11.8
# 120 == Blackwell, needs CUDA v12.8, FP4 tensor cores
#
# XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run
# XX-real == compile CUDA code as device code for this specific architecture
@ -36,10 +37,36 @@ if (CUDAToolkit_FOUND)
endif()
endif()
endif()
message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
enable_language(CUDA)
# Replace any 12x-real architectures with 12x{a}-real. FP4 ptx instructions are not available in just 12x
if (GGML_NATIVE)
set(PROCESSED_ARCHITECTURES "")
if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES AND CMAKE_CUDA_ARCHITECTURES_NATIVE)
set(ARCH_LIST ${CMAKE_CUDA_ARCHITECTURES_NATIVE})
else()
set(ARCH_LIST ${CMAKE_CUDA_ARCHITECTURES})
endif()
foreach(ARCH ${ARCH_LIST})
if (ARCH MATCHES "^12[0-9](-real|-virtual)?$")
string(REGEX REPLACE "^(12[0-9]).*$" "\\1" BASE_ARCH ${ARCH})
message(STATUS "Replacing ${ARCH} with ${BASE_ARCH}a-real")
list(APPEND PROCESSED_ARCHITECTURES "${BASE_ARCH}a-real")
else()
list(APPEND PROCESSED_ARCHITECTURES ${ARCH})
endif()
endforeach()
set(CMAKE_CUDA_ARCHITECTURES ${PROCESSED_ARCHITECTURES})
else()
foreach(ARCH ${CMAKE_CUDA_ARCHITECTURES})
if(ARCH MATCHES "^12[0-9](-real|-virtual)?$")
message(FATAL_ERROR "Compute capability ${ARCH} used, use ${ARCH}a or ${ARCH}f for Blackwell specific optimizations")
endif()
endforeach()
endif()
message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
file(GLOB GGML_HEADERS_CUDA "*.cuh")
list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")

View File

@ -50,6 +50,10 @@
#define GGML_CUDA_CC_TURING 750
#define GGML_CUDA_CC_AMPERE 800
#define GGML_CUDA_CC_ADA_LOVELACE 890
// While BW spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to 1200 family, see
// https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms
#define GGML_CUDA_CC_BLACKWELL 1200
#define GGML_CUDA_CC_RUBIN 1300
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
#define GGML_CUDA_CC_IS_NVIDIA(cc) (cc < GGML_CUDA_CC_OFFSET_MTHREADS)
@ -246,6 +250,10 @@ static const char * cu_get_error_str(CUresult err) {
#define AMPERE_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL && __CUDA_ARCH__ < GGML_CUDA_CC_RUBIN
# define BLACKWELL_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#define CP_ASYNC_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
@ -316,6 +324,11 @@ static bool cp_async_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
}
static bool blackwell_mma_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_BLACKWELL &&
ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_RUBIN;
}
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
return 64;
@ -701,6 +714,28 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
#endif // CUDART_VERSION >= 12050
}
__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {
const uint8_t sign_bit = (x < 0.0f) << 3;
float ax = fabsf(x) * e;
// Positive LUT
static constexpr float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };
int best_i = 0;
float best_err = fabsf(ax - pos_lut[0]);
#pragma unroll
for (int i = 1; i < 8; ++i) {
const float err = fabsf(ax - pos_lut[i]);
if (err < best_err) {
best_err = err;
best_i = i;
}
}
return static_cast<uint8_t>(best_i | sign_bit);
}
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
// Precompute mp (m' in the paper) and L such that division
// can be computed using a multiply (high 32b of 64b result)

View File

@ -5,7 +5,7 @@
#include "ggml.h"
#ifdef GGML_CUDA_USE_CUB
# include <cub/device/device_scan.cuh>
# include <cub/block/block_scan.cuh>
#endif // GGML_CUDA_USE_CUB
template<typename T, int BLOCK_SIZE>
@ -16,12 +16,14 @@ static __global__ void cumsum_cub_kernel(
const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t s1, const int64_t s2, const int64_t s3) {
#ifdef GGML_CUDA_USE_CUB
using BlockScan = cub::BlockScan<T, BLOCK_SIZE>;
using BlockScanT = cub::BlockScan<T, BLOCK_SIZE>;
__shared__ typename BlockScan::TempStorage temp_storage;
__shared__ T block_carry; // carry from previous tile
__shared__ typename BlockScanT::TempStorage temp_storage;
__shared__ T block_carry;
const int tid = threadIdx.x;
constexpr int UNROLL_FACTOR = 4;
constexpr int TILE_SIZE = BLOCK_SIZE * UNROLL_FACTOR;
const int64_t i1 = blockIdx.x;
const int64_t i2 = blockIdx.y;
@ -39,29 +41,38 @@ static __global__ void cumsum_cub_kernel(
}
__syncthreads();
for (int64_t start = 0; start < ne00; start += BLOCK_SIZE) {
int64_t idx = start + tid;
T x = (idx < ne00) ? src_row[idx] : T(0);
for (int64_t start = 0; start < ne00; start += TILE_SIZE) {
T items[UNROLL_FACTOR];
T thread_sum = T(0);
T inclusive;
T block_total;
BlockScan(temp_storage).InclusiveSum(x, inclusive, block_total);
__syncthreads();
T final_val = inclusive + block_carry;
// store result
if (idx < ne00) {
dst_row[idx] = final_val;
#pragma unroll
for (int i = 0; i < UNROLL_FACTOR; i++) {
int64_t idx = start + tid * UNROLL_FACTOR + i;
T val = (idx < ne00) ? src_row[idx] : T(0);
thread_sum += val;
items[i] = thread_sum;
}
// Block-wide scan on thread sums
T thread_prefix;
T block_total;
BlockScanT(temp_storage).InclusiveSum(thread_sum, thread_prefix, block_total);
__syncthreads();
// Add offset to each item and store
T thread_offset = thread_prefix - thread_sum + block_carry;
#pragma unroll
for (int i = 0; i < UNROLL_FACTOR; i++) {
int64_t idx = start + tid * UNROLL_FACTOR + i;
if (idx < ne00) {
dst_row[idx] = items[i] + thread_offset;
}
}
// Update carry for next tile
if (tid == 0) {
block_carry += block_total;
}
__syncthreads();
}
#else
@ -69,7 +80,7 @@ static __global__ void cumsum_cub_kernel(
#endif // GGML_CUDA_USE_CUB
}
// Fallback kernel implementation (original)
// Fallback kernel implementation
template<typename T>
static __global__ void cumsum_kernel(
const T * src, T * dst,
@ -86,10 +97,10 @@ static __global__ void cumsum_kernel(
const int warps_per_block = blockDim.x / warp_size;
extern __shared__ float smem[];
float * s_vals = smem;
float * s_warp_sums = smem + blockDim.x;
float * s_carry = smem + blockDim.x + warps_per_block;
float * s_chunk_total = s_carry + 1;
float * s_vals = smem;
float * s_warp_sums = smem + blockDim.x;
float * s_carry = smem + blockDim.x + warps_per_block;
float * s_chunk_total = s_carry + 1;
// Initialize carry
if (tid == 0) {
@ -107,21 +118,39 @@ static __global__ void cumsum_kernel(
const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3;
for (int64_t start = 0; start < ne00; start += blockDim.x) {
int64_t idx = start + tid;
float val = (idx < ne00) ? ggml_cuda_cast<float, T>(src_row[idx]) : 0.0f;
// register blocking: process 4 elements per thread to hide latency
// and reduce synchronization overhead
constexpr int num_unroll = 4;
T temp[num_unroll];
// 1. Warp inclusive scan
for (int64_t i = 0; i < ne00; i += num_unroll * blockDim.x) {
int64_t idx = i + tid * num_unroll;
// thread local sequential scan
temp[0] = (idx < ne00 ? src_row[idx] : T(0));
#pragma unroll
for (int64_t j = 1; j < num_unroll; j++) {
temp[j] = temp[j - 1];
if (idx + j < ne00) {
temp[j] += src_row[idx + j];
} else {
temp[j] += 0;
}
}
// last emenent is sum of all values assigned to thread
float val = (idx < ne00) ? ggml_cuda_cast<float, T>(temp[num_unroll - 1]) : 0.0f;
// Warp inclusive scan
val = warp_prefix_inclusive_sum<T, warp_size>(val);
s_vals[tid] = val;
// Store warp total
if (lane == warp_size - 1) {
s_warp_sums[warp] = val;
}
__syncthreads();
// 2. Exclusive scan of warp sums (warp 0 only)
// Exclusive scan of warp sums (warp 0 only)
if (warp == 0) {
float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
float inc = warp_prefix_inclusive_sum<T, warp_size>(w);
@ -134,12 +163,17 @@ static __global__ void cumsum_kernel(
}
__syncthreads();
// write back results
float carry = *s_carry;
float final_val = s_vals[tid] + s_warp_sums[warp] + carry;
if (idx < ne00) {
dst_row[idx] = ggml_cuda_cast<T, float>(final_val);
// calculate sum offset for this thread
float final_val_offset = s_vals[tid] + s_warp_sums[warp] + carry - temp[num_unroll - 1];
#pragma unroll
for (int32_t j = 0; j < num_unroll; j++) {
if (idx + j < ne00) {
dst_row[idx + j] = temp[j] + ggml_cuda_cast<T, float>(final_val_offset);
}
}
__syncthreads();
// Update carry for next chunk
if (tid == 0) {
@ -177,7 +211,7 @@ static void cumsum_cuda(
const int warps_per_block = block_size / warp_size;
const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float);
if (use_cub) {
if (use_cub && ne00 >= 1024) {
cumsum_cub_kernel<T, CUDA_CUMSUM_BLOCK_SIZE><<<grid_dims, CUDA_CUMSUM_BLOCK_SIZE, 0, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,

View File

@ -900,6 +900,27 @@ namespace ggml_cuda_mma {
#endif // AMPERE_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
const tile<16, 8, int> & A,
const tile<8, 8, int> & B,
uint32_t a_scale,
uint32_t b_scale) {
#ifdef BLACKWELL_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
float * Dxi = (float *) D.x;
asm volatile(
"mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
"%10, {0, 0}, %11, {0, 0};"
: "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
#else
GGML_UNUSED_VARS(D, A, B, a_scale, b_scale);
#endif // BLACKWELL_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
#ifdef TURING_MMA_AVAILABLE

View File

@ -1,3 +1,4 @@
#include "common.cuh"
#include "mmq.cuh"
#include "quantize.cuh"
#include "mmid.cuh"
@ -114,6 +115,9 @@ void ggml_cuda_mul_mat_q(
const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
|| GGML_CUDA_CC_IS_CDNA(cc);
// TODO: tighter pool buffer size vs q8 path
const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4;
if (!ids) {
const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
@ -123,12 +127,24 @@ void ggml_cuda_mul_mat_q(
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[3] / ts_src1;
quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type,
ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
if (use_native_mxfp4) {
static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1));
quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
ne11, ne12, ne13, stream);
} else {
quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
ne11, ne12, ne13, stream);
}
CUDA_CHECK(cudaGetLastError());
}
const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
// Stride depends on quantization format
const int64_t s12 = use_native_mxfp4 ?
ne11 * ne10_padded * sizeof(block_fp4_mmq) /
(8 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 256 values (8 blocks of 32)
:
ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
const int64_t s13 = ne12*s12;
const mmq_args args = {
@ -175,12 +191,19 @@ void ggml_cuda_mul_mat_q(
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[2] / ts_src1;
quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type,
ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
if (use_native_mxfp4) {
quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
} else {
quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
}
CUDA_CHECK(cudaGetLastError());
}
const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) :
ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
const int64_t s13 = ne12*s12;
// Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.

View File

@ -11,6 +11,7 @@ using namespace ggml_cuda_mma;
#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
#define MMQ_ITER_K 256
#define MMQ_ITER_K_MXFP4_FP4 512
#define MMQ_NWARPS 8
typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
@ -44,8 +45,15 @@ struct block_q8_1_mmq {
};
int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
};
struct block_fp4_mmq {
uint32_t d4[4]; // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc.
int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values
};
static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected block_fp4_mmq size");
static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
switch (type_x) {
@ -129,6 +137,14 @@ static int get_mmq_y_host(const int cc) {
((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64);
}
static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) {
#if defined(BLACKWELL_MMA_AVAILABLE)
return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K;
#else
return MMQ_ITER_K;
#endif // defined(BLACKWELL_MMA_AVAILABLE)
}
static constexpr __device__ int get_mmq_y_device() {
#if defined(GGML_USE_HIP)
#if defined(RDNA1)
@ -191,6 +207,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
}
#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4)
#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
@ -201,6 +218,8 @@ static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4");
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
switch (type) {
@ -209,6 +228,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
// tile sizes are the same for Q8_1 and FP4 for blackwell
case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
@ -228,7 +248,8 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
}
// block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales)
#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1)
#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K / QI8_1)
#define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K
static int mmq_get_granularity_host(const int mmq_x, const int cc) {
if (amd_mfma_available(cc) || amd_wmma_available(cc)) {
@ -761,6 +782,50 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
}
}
template <int mmq_y, bool need_check>
static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restrict__ x,
int * __restrict__ x_tile,
const int kbx0,
const int i_max,
const int stride) {
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
int * x_qs = (int *) x_tile;
uint32_t * x_sc = (uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
const int txi = threadIdx.x;
constexpr int iter_k = get_iter_k(GGML_TYPE_MXFP4);
constexpr int threads_per_row = iter_k / QK_MXFP4; // each thread processes 1 block
constexpr int rows_per_warp = warp_size / threads_per_row;
const int kbx = txi % threads_per_row;
const int row_in_warp = txi / threads_per_row;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
if constexpr (need_check) {
i = min(i, i_max);
}
const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx;
// quantize_mxfp4_mmq permutes nibbles to match the quantized format
const int k0 = kbx * 4;
memcpy(x_qs + i * MMQ_MMA_TILE_X_K_FP4 + k0, bxi->qs, 16);
// Load E8M0 scales: pack 2 consecutive scales into one uint32
if (kbx % 2 == 0) {
uint32_t e = bxi->e;
e |= ((bxi + 1)->e << 8);
x_sc[i * MMQ_MMA_TILE_X_K_FP4 + kbx / 2] = e;
}
}
}
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@ -931,6 +996,78 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x,
const int * __restrict__ y,
float * __restrict__ sum,
const int k00) {
typedef tile<16, 8, int> tile_A;
typedef tile<8, 8, int> tile_B;
typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
constexpr int ntx = rows_per_warp / tile_C::I; // Number of x minitiles per warp.
y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K);
// Match layout from load_tiles_mxfp4_fp4
const int * x_qs = (const int *) x;
const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
const int * y_qs = (const int *) y + 4;
const uint32_t * y_sc = (const uint32_t *) y;
// tile_A has a length of 64 logical values vs. 32 values in block_mxfp4
tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
// Block scale
// Each thread has to point to a 4 byte scale value
// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
const int k0 = k00 + k01;
load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0,
MMQ_MMA_TILE_X_K_FP4);
// based on block-scaling document, 2 threads in each quad need to supply to the scale value
const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
scaleA[n][k01 / (2 * QI_MXFP4)] =
*(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4));
}
}
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
#pragma unroll
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
tile_B B;
uint32_t scaleB; // 2xN scales
load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K);
scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)];
#pragma unroll
for (int n = 0; n < ntx; ++n) {
tile_C C;
mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB);
#pragma unroll
for (int l = 0; l < tile_C::ne; ++l) {
sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
}
}
}
}
}
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@ -3109,8 +3246,13 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
#ifdef BLACKWELL_MMA_AVAILABLE
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_mxfp4_mxfp4_mma<mmq_x, mmq_y>;
#else
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
#endif // BLACKWELL_MMA_AVAILABLE
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};
@ -3243,17 +3385,26 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
#if defined(BLACKWELL_MMA_AVAILABLE)
// FP4 tile stores 8 blocks
constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1;
#else
constexpr int ne_block = 4 * QK8_1;
#endif // defined(BLACKWELL_MMA_AVAILABLE)
constexpr int ITER_K = get_iter_k(type);
constexpr int blocks_per_iter = ITER_K / qk;
float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
constexpr int sz = sizeof(block_q8_1_mmq) / sizeof(int);
for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
{
const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
const int * by0 = y + ncols_y * (kb0 * qk / ne_block) * sz;
#pragma unroll
for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
int l = l0 + threadIdx.y*warp_size + threadIdx.x;
tile_y[l] = by0[l];
@ -3267,9 +3418,9 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
__syncthreads();
{
const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
const int * by0 = y + ncols_y * ((kb0 * qk / ne_block) * sz + sz);
#pragma unroll
for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
int l = l0 + threadIdx.y*warp_size + threadIdx.x;
tile_y[l] = by0[l];
@ -3401,8 +3552,10 @@ static __global__ void mul_mat_q(
}
#endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
constexpr int ITER_K = get_iter_k(type);
const int64_t blocks_per_ne00 = ncols_x / qk;
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
constexpr int blocks_per_iter = ITER_K / qk;
// kbc == k block continuous, current index in continuous ijk space.
int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
@ -3463,7 +3616,7 @@ static __global__ void mul_mat_q(
__syncthreads();
}
offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
offset_dst += it*mmq_y;
const int tile_x_max_i = nrows_x - it*mmq_y - 1;
@ -3530,7 +3683,7 @@ static __global__ void mul_mat_q(
__syncthreads();
}
offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
offset_dst += it*mmq_y;
const int tile_x_max_i = nrows_x - it*mmq_y - 1;
@ -3553,7 +3706,9 @@ static __global__ void mul_mat_q_stream_k_fixup(
const int ncols_max) {
constexpr int mmq_y = get_mmq_y_device();
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
constexpr int ITER_K = get_iter_k(type);
constexpr int blocks_per_iter = ITER_K / qk;
const int64_t blocks_per_ne00 = ncols_x / qk;
constexpr int nwarps = mmq_get_nwarps_device();
@ -3711,7 +3866,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
const size_t nbs_ids = mmq_x*sizeof(int);
const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
const size_t nbs_y = mmq_x * (sizeof(block_q8_1_mmq));
return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
}

View File

@ -47,6 +47,131 @@ static __global__ void quantize_q8_1(
y[ib].ds = make_half2(d, sum);
}
__device__ __forceinline__ uint8_t compute_e8m0_scale(float amax) {
if (!(amax > 0.0f)) {
return 0;
}
// FP4 E2M1: max exponent (unbiased) is 2.
constexpr int FP4_E2M1_EMAX = 2;
const float e = log2f(amax);
// "even" -> round-to-nearest integer, ties-to-even
const int e_int = __float2int_rn(e);
const int shared_exp = e_int - FP4_E2M1_EMAX;
int biased = shared_exp + 127;
biased = max(biased, 0);
biased = min(biased, 254);
return static_cast<uint8_t>(biased);
}
// quantize values in the format mxfp4 is stored which is interleaved nibbles
// i.e. a block a0-a31 is represented as a0a16,a1a17 ...a15a31
static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
const int32_t * __restrict__ ids,
void * __restrict__ vy,
const int64_t ne00,
const int64_t s01,
const int64_t s02,
const int64_t s03,
const int64_t ne0,
const int ne1,
const int ne2) {
constexpr int vals_per_scale = 32;
constexpr int vals_per_warp = 2 * vals_per_scale; // Each warp processes 2 blocks of 32 = 64 values
const int warp_id = threadIdx.y;
const int lane_id_32 = threadIdx.x;
const int nwarps = blockDim.y;
const int64_t warp_start_offset = (blockIdx.y * nwarps + warp_id) * vals_per_warp;
if (warp_start_offset >= ne0) {
return;
}
const int64_t i1 = blockIdx.x;
const int64_t i2 = blockIdx.z % ne2;
const int64_t i3 = blockIdx.z / ne2;
const int64_t i01 = ids ? ids[i1] : i1;
const int64_t i02 = i2;
const int64_t i03 = i3;
block_fp4_mmq * y = (block_fp4_mmq *) vy;
const int64_t block_fp4_mmq_size = 8 * QK_MXFP4; // 256 values
const int64_t ib0 = blockIdx.z * ((int64_t) ne1 * (ne0 / block_fp4_mmq_size));
const int64_t ib = ib0 + (warp_start_offset / block_fp4_mmq_size) * ne1 + blockIdx.x;
const int64_t quad_idx_in_block = (warp_start_offset % block_fp4_mmq_size) / vals_per_warp;
const int group_id = lane_id_32 / 4;
const int lane_in_group = lane_id_32 % 4;
const int base = group_id * 2;
char2 * yqs2 = (char2 *) y[ib].qs;
const int64_t base_pos = i03 * s03 + i02 * s02 + i01 * s01;
uint8_t scales[2];
#pragma unroll
for (int b = 0; b < 2; ++b) {
const int64_t i0 = warp_start_offset + b * vals_per_scale + lane_id_32;
const float xi = (i0 < ne00) ? x[base_pos + i0] : 0.0f;
float amax = fabsf(xi);
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
}
const uint8_t e = compute_e8m0_scale(amax);
scales[b] = e;
const float inv_s = (amax == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e));
#if CUDART_VERSION >= 12080
const float scaled_val = xi * inv_s;
const float val0 = __shfl_sync(0xFFFFFFFF, scaled_val, base, WARP_SIZE);
const float val1 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 16, WARP_SIZE);
const float val2 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 1, WARP_SIZE);
const float val3 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 17, WARP_SIZE);
if (lane_in_group == 0) {
__nv_fp4x4_e2m1 fp4_packed(make_float4(val0, val1, val2, val3));
yqs2[quad_idx_in_block * 16 + b * 8 + group_id] = *(char2 *) &fp4_packed;
}
#else
// Fallback: manual FP4 conversion using LUT
const uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s);
const uint8_t q_lo_0 = __shfl_sync(0xFFFFFFFF, q_val, base, WARP_SIZE);
const uint8_t q_lo_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 1, WARP_SIZE);
const uint8_t q_hi_0 = __shfl_sync(0xFFFFFFFF, q_val, base + 16, WARP_SIZE);
const uint8_t q_hi_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 17, WARP_SIZE);
if (lane_in_group == 0) {
char2 q;
q.x = (q_hi_0 << 4) | q_lo_0;
q.y = (q_hi_1 << 4) | q_lo_1;
yqs2[quad_idx_in_block * 16 + b * 8 + group_id] = q;
}
#endif // CUDART_VERSION >= 12080
}
if (lane_id_32 == 0) {
// Store 2 scales packed into 1 uint32
y[ib].d4[quad_idx_in_block] = (scales[1] << 8) | scales[0];
}
}
template <mmq_q8_1_ds_layout ds_layout>
static __global__ void quantize_mmq_q8_1(
const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy,
@ -190,3 +315,29 @@ void quantize_mmq_q8_1_cuda(
break;
}
}
void quantize_mmq_mxfp4_cuda(const float * x,
const int32_t * ids,
void * vy,
[[maybe_unused]] const ggml_type type_src0,
const int64_t ne00,
const int64_t s01,
const int64_t s02,
const int64_t s03,
const int64_t ne0,
const int64_t ne1,
const int64_t ne2,
const int64_t ne3,
cudaStream_t stream) {
GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0);
constexpr int nwarps = 8;
constexpr int vals_per_warp = 2 * QK_MXFP4;
constexpr int vals_per_block = nwarps * vals_per_warp;
const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block;
const dim3 num_blocks(ne1, block_num_y, ne2 * ne3);
const dim3 block_size(WARP_SIZE, nwarps, 1);
quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
}

View File

@ -25,3 +25,17 @@ void quantize_mmq_q8_1_cuda(
const float * x, const int32_t * ids, void * vy,
ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
void quantize_mmq_mxfp4_cuda(const float * x,
const int32_t * ids,
void * vy,
ggml_type type_src0,
int64_t ne00,
int64_t s01,
int64_t s02,
int64_t s03,
int64_t ne0,
int64_t ne1,
int64_t ne2,
int64_t ne3,
cudaStream_t stream);

View File

@ -10,6 +10,10 @@
#include <cuda_fp8.h>
#endif // CUDART_VERSION >= 12050
#if CUDART_VERSION >= 12080
#include <cuda_fp4.h>
#endif // CUDART_VERSION >= 12080
#if CUDART_VERSION < 11020
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH

View File

@ -24,10 +24,6 @@
#include <arm_neon.h>
#endif
#if defined(__F16C__)
#include <immintrin.h>
#endif
#ifdef __cplusplus
extern "C" {
#endif

View File

@ -263,6 +263,32 @@ static ggml_cl_compiler_version get_adreno_cl_compiler_version(const char *drive
return { type, major, minor, patch };
}
// cl buffer wrapper
struct ggml_cl_buffer {
cl_mem buffer;
size_t size;
ggml_cl_buffer()
: buffer(nullptr), size(0) {}
~ggml_cl_buffer() {
if (buffer) {
CL_CHECK(clReleaseMemObject(buffer));
}
}
void allocate(cl_context context, size_t new_size) {
if (new_size > size) {
size = new_size;
if (buffer) {
CL_CHECK(clReleaseMemObject(buffer));
}
cl_int err;
CL_CHECK((buffer = clCreateBuffer(context, CL_MEM_READ_WRITE, size, NULL, &err), err));
}
}
};
// Profiling
struct ProfilingInfo {
std::string op_name;
@ -376,6 +402,11 @@ struct ggml_backend_opencl_context {
cl_context context;
cl_command_queue queue;
// prealloc buffers for transposing weights and activations
ggml_cl_buffer prealloc_quant_trans;
ggml_cl_buffer prealloc_scales_trans;
ggml_cl_buffer prealloc_act_trans;
cl_program program_add;
cl_program program_add_id;
cl_program program_clamp;
@ -638,10 +669,6 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_transpose_16_buf;
cl_kernel kernel_transpose_16_4x1;
cl_mem A_s_d_max; // max scale buffer size for transpose
cl_mem A_q_d_max; // max weight buffer size for transpose
cl_mem B_d_max; // max activation buffer size for transpose
// Gemm and Gemv related programs, kernels, etc
cl_program program_CL_gemm;
cl_program program_CL_gemv_general;
@ -2600,9 +2627,9 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
required_B_d_bytes, max_B_d_bytes);
}
CL_CHECK((backend_ctx->A_q_d_max = clCreateBuffer(context, 0, max_A_q_d_bytes, NULL, &err), err));
CL_CHECK((backend_ctx->A_s_d_max = clCreateBuffer(context, 0, max_A_s_d_bytes, NULL, &err), err));
CL_CHECK((backend_ctx->B_d_max = clCreateBuffer(context, 0, max_B_d_bytes, NULL, &err), err));
backend_ctx->prealloc_quant_trans.allocate(context, max_A_q_d_bytes);
backend_ctx->prealloc_scales_trans.allocate(context, max_A_s_d_bytes);
backend_ctx->prealloc_act_trans.allocate(context, max_B_d_bytes);
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
backend_ctx->disable_fusion = getenv("GGML_OPENCL_DISABLE_FUSION") != nullptr;
@ -3607,32 +3634,35 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
// use sub_buffer of max buffer size instead
size_t q_size_bytes = K * M / 8 * sizeof(float);
backend_ctx->prealloc_quant_trans.allocate(context, q_size_bytes);
cl_buffer_region region;
region.origin = 0;
region.size = q_size_bytes;
cl_mem qT_d = clCreateSubBuffer(
backend_ctx->A_q_d_max,
backend_ctx->prealloc_quant_trans.buffer,
0,
CL_BUFFER_CREATE_TYPE_REGION,
&region,
&err);
// cl_mem qT_d = clCreateBuffer(context, CL_MEM_READ_WRITE, q_size_bytes, NULL, &err);
CL_CHECK(err);
bool K_tile_trans = true;
if ((K / 32) % 4 != 0){
K_tile_trans =false;
}
size_t d_size_bytes = M * (K / 32) * 2;
backend_ctx->prealloc_scales_trans.allocate(context, d_size_bytes);
region.origin = 0;
region.size = d_size_bytes;
cl_mem dT_d = clCreateSubBuffer(
backend_ctx->A_s_d_max,
backend_ctx->prealloc_scales_trans.buffer,
0,
CL_BUFFER_CREATE_TYPE_REGION,
&region,
&err);
// cl_mem dT_d = clCreateBuffer(context, CL_MEM_READ_WRITE, d_size_bytes, NULL, &err);
CL_CHECK(err);
// <----------------------------------------------------------------------------------> //
@ -7395,8 +7425,10 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
region.origin = 0;
// Specify the size of the sub-buffer (divide by 2 for FP16)
region.size = K * (N + padding) * sizeof(float)/2;
backend_ctx->prealloc_act_trans.allocate(context, region.size);
B_d = clCreateSubBuffer(
backend_ctx->B_d_max,
backend_ctx->prealloc_act_trans.buffer,
0,
CL_BUFFER_CREATE_TYPE_REGION,
&region,

View File

@ -379,18 +379,18 @@ enum FaCodePath {
};
struct vk_fa_pipeline_state {
vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, FaCodePath path, bool aligned, bool f32acc)
: HSK(HSK), HSV(HSV), small_rows(small_rows), path(path), aligned(aligned), f32acc(f32acc) {}
vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc)
: HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc) {}
uint32_t HSK, HSV;
bool small_rows;
bool small_rows, small_cache;
FaCodePath path;
bool aligned;
bool f32acc;
bool operator<(const vk_fa_pipeline_state &b) const {
return std::tie(HSK, HSV, small_rows, path, aligned, f32acc) <
std::tie(b.HSK, b.HSV, b.small_rows, b.path, b.aligned, b.f32acc);
return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc) <
std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc);
}
};
@ -651,7 +651,7 @@ struct vk_device_struct {
vk_pipeline pipeline_add_id_f32;
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32;
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32, pipeline_upscale_bilinear_antialias_f32;
vk_pipeline pipeline_scale_f32;
vk_pipeline pipeline_sqr_f32;
vk_pipeline pipeline_sqrt_f32;
@ -763,6 +763,7 @@ struct vk_device_struct {
std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT];
vk_pipeline pipeline_flash_attn_split_k_reduce;
vk_pipeline pipeline_count_experts;
// [2] is for whether to take n_experts from spec constant (0) or push constant (1)
vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT][2];
@ -1004,6 +1005,14 @@ struct vk_op_push_constants {
float param4;
};
struct vk_op_count_experts_push_constants {
uint32_t ne00;
uint32_t ne01;
uint32_t nb00;
uint32_t nb01;
uint32_t a_offset;
};
struct vk_op_glu_push_constants {
uint32_t N;
uint32_t ne00;
@ -1192,6 +1201,7 @@ struct vk_op_diag_mask_push_constants {
struct vk_op_rope_push_constants {
uint32_t rope_mode;
uint32_t ncols;
uint32_t nrows;
uint32_t n_dims;
float freq_scale;
uint32_t p_delta_rows;
@ -1564,7 +1574,7 @@ class vk_perf_logger {
total_op_times += time;
}
std::cerr << t.first << ": " << t.second.size() << " x " << (total_op_times / t.second.size() / 1000.0)
<< " us";
<< " us = " << (total_op_times / 1000.0) << " us";
// If we have as many flops entries as timing entries for the op, then compute and log the flops/S.
auto it = flops.find(t.first);
@ -2582,10 +2592,10 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
static constexpr uint32_t flash_attention_num_small_rows = 32;
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv) {
static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) {
if (hsv >= 192) {
return 2;
} else if ((hsv | hsk) & 8) {
} else if ((hsv | hsk) & 8 || small_cache) {
return 4;
} else {
return 8;
@ -2607,9 +2617,8 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) {
}
}
static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) {
static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) {
GGML_UNUSED(clamp);
GGML_UNUSED(hsv);
if (path == FA_SCALAR) {
if (small_rows) {
@ -2618,9 +2627,9 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
if ((hsv | hsk) & 8) {
// HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
// larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
return {get_fa_scalar_num_large_rows(hsk, hsv), 64};
return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64};
} else {
return {get_fa_scalar_num_large_rows(hsk, hsv), 32};
return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 32};
}
}
}
@ -2649,8 +2658,8 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
return {64, 64};
}
static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows) {
return fa_rows_cols(path, hsk, hsv, 0, type, small_rows)[1];
static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows, bool small_cache) {
return fa_rows_cols(path, hsk, hsv, 0, type, small_rows, small_cache)[1];
}
static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
@ -2830,9 +2839,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
s_mmq_wg_denoms_k = { 32, 64, 1 };
// spec constants and tile sizes for quant matmul_id
l_warptile_mmqid = { 256, 128, 128, 16, 1, device->subgroup_size };
m_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
s_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
l_warptile_mmqid = { 256, 128, 128, 32, 1, device->subgroup_size };
m_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size };
s_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size };
l_mmqid_wg_denoms = { 128, 128, 1 };
m_mmqid_wg_denoms = { 128, 64, 1 };
s_mmqid_wg_denoms = { 128, 64, 1 };
@ -2992,11 +3001,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
align, disable_robustness, require_full_subgroups, required_subgroup_size);
};
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::array<uint32_t, 3> {
return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1};
};
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::vector<uint32_t> {
// For large number of rows, 128 invocations seems to work best.
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
// can't use 256 for D==80.
@ -3006,7 +3015,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
? scalar_flash_attention_workgroup_size
: ((small_rows && (D % 32) == 0) ? 256 : 128);
auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows);
auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache);
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
@ -3021,21 +3030,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
uint32_t HSK = fa.first.HSK; \
uint32_t HSV = fa.first.HSV; \
bool small_rows = fa.first.small_rows; \
bool small_cache = fa.first.small_cache; \
FaCodePath path = fa.first.path; \
bool aligned = fa.first.aligned; \
bool f32acc = fa.first.f32acc; \
if (path == FAPATH) { \
if (aligned) { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} \
} else { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} \
} \
} \
@ -3067,17 +3077,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
#endif
#undef CREATE_FA
const int mul_mat_id_param_count = 5;
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (device->coopmat2) {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, true); \
// Create 2 variants, {f16,f32} accumulator
#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
@ -3113,32 +3125,32 @@ static void ggml_vk_load_shaders(vk_device& device) {
GGML_ASSERT(device->subgroup_ballot);
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 5)
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
if (device->coopmat_bf16_support) {
CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 5)
}
#endif
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
#undef CREATE_MM
#undef CREATE_MM2
} else
@ -3227,35 +3239,35 @@ static void ggml_vk_load_shaders(vk_device& device) {
GGML_ASSERT(device->subgroup_ballot);
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id);
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
if (device->coopmat_bf16_support) {
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id);
}
#endif
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
#undef CREATE_MM2
#undef CREATE_MM
} else
@ -3340,91 +3352,91 @@ static void ggml_vk_load_shaders(vk_device& device) {
#endif
if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
}
#endif
} else {
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
}
#endif
}
@ -3501,57 +3513,57 @@ static void ggml_vk_load_shaders(vk_device& device) {
#endif
if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_subgroup_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_subgroup_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_subgroup_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_subgroup_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_subgroup_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_subgroup_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_subgroup_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_subgroup_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_subgroup_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_subgroup_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_subgroup_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_subgroup_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_subgroup_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_subgroup_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_subgroup_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_subgroup_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_subgroup_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_subgroup_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_subgroup_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_subgroup_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_subgroup_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_subgroup_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_subgroup_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_subgroup_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_subgroup_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_subgroup_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
} else {
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
}
}
// reusing CREATE_MM from the fp32 path
@ -3570,7 +3582,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
s_wg_denoms = { 32, 32, 1 };
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
}
#undef CREATE_MM
@ -3955,6 +3967,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
ggml_vk_create_pipeline(device, device->pipeline_upscale_bicubic_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BICUBIC}, 1);
ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_antialias_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS}, 1);
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@ -4126,6 +4139,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_count_experts, "count_experts", count_experts_len, count_experts_data, "main", 2, sizeof(vk_op_count_experts_push_constants), {1, 1, 1}, {}, 1, true);
for (auto &s : device->pipeline_solve_tri_f32) {
const vk_solve_tri_pipeline_state &state = s.first;
@ -6523,18 +6538,18 @@ static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context *
static void ggml_vk_matmul_id(
ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,
vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids, const vk_subbuffer & expert_count_buf,
uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11,
uint32_t padded_n) {
VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " <<
VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), expert_count: (" << expert_count_buf.buffer->buffer << ", " << expert_count_buf.offset << ", " << expert_count_buf.size << "), " <<
"m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " <<
"batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
"n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")");
const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
nei0, nei1, nbi1, ne11, padded_n };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, pc, { m, nei1, n_as });
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids, expert_count_buf }, pc, { m, nei1, n_as });
}
static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
@ -7517,6 +7532,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
const uint64_t nei0 = ids->ne[0];
const uint64_t nei1 = ids->ne[1];
const uint32_t nbi0 = ids->nb[0];
const uint32_t nbi1 = ids->nb[1];
const uint32_t nbi2 = ids->nb[2];
@ -7624,6 +7640,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (quantize_y) {
to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
}
vk_pipeline count_experts = ctx->device->pipeline_count_experts;
uint32_t expert_count_size = sizeof(uint32_t) * n_as;
{
if (
@ -7639,6 +7658,10 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
ctx->prealloc_size_y = y_sz;
ggml_vk_preallocate_buffers(ctx, subctx);
}
if (ctx->prealloc_size_split_k < expert_count_size) {
ctx->prealloc_size_split_k = expert_count_size;
ggml_vk_preallocate_buffers(ctx, subctx);
}
// Request descriptor sets
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
@ -7651,6 +7674,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (quantize_y) {
ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
}
ggml_pipeline_request_descriptor_sets(ctx, count_experts, 1);
}
vk_buffer d_D = dst_buf_ctx->dev_buffer;
@ -7700,6 +7724,20 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
ggml_vk_sync_buffers(ctx, subctx);
}
}
// Count how many times each expert is used
vk_subbuffer expert_count_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
if (ctx->prealloc_split_k_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
{
const std::vector<uint32_t> pc = { (uint32_t)nei0,
(uint32_t)nei1,
(uint32_t)(nbi0 / ggml_type_size(ids->type)),
(uint32_t)(nbi1 / ggml_type_size(ids->type)),
(uint32_t)(get_misalign_bytes(ctx, ids) / ggml_type_size(ids->type)) };
ggml_vk_dispatch_pipeline(ctx, subctx, count_experts,
{ vk_subbuffer{ d_ids, ids_buf_offset, ids_sz }, expert_count_buf }, pc, { (uint32_t)n_as, 1, 1});
}
if (x_non_contig) {
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0));
@ -7707,7 +7745,6 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0,
{ vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_X, 0, x_sz } }, pc, { (uint32_t)x_ne, 1, 1});
ggml_vk_sync_buffers(ctx, subctx);
}
if (y_non_contig) {
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
@ -7731,6 +7768,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
ctx->prealloc_y_last_tensor_used = src1;
}
}
ggml_vk_sync_buffers(ctx, subctx);
uint32_t stride_batch_x = ne00*ne01;
uint32_t stride_batch_y = ne10*ne11;
@ -7747,7 +7785,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
ggml_vk_matmul_id(
ctx, subctx, pipeline,
{ d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz },
{ d_D, d_buf_offset, d_sz }, { d_ids, ids_buf_offset, ids_sz },
{ d_D, d_buf_offset, d_sz }, { d_ids, ids_buf_offset, ids_sz }, expert_count_buf,
ne01, ne21, ne10, ne10, ne10, ne01,
stride_batch_x, stride_batch_y, ne20*ne21,
n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n
@ -7759,6 +7797,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (y_non_contig || quantize_y) {
ctx->prealloc_y_need_sync = true;
}
ctx->prealloc_split_k_need_sync = true;
}
static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
@ -8008,11 +8047,11 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
}
}
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) {
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool small_cache) {
// Needs to be kept up to date on shader changes
GGML_UNUSED(hsv);
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv);
const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv, small_cache);
const uint32_t Bc = scalar_flash_attention_Bc;
const uint32_t tmpsh = wg_size * sizeof(float);
@ -8136,6 +8175,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
uint32_t workgroups_y = (uint32_t)neq2;
uint32_t workgroups_z = (uint32_t)neq3;
const bool small_cache = nek1 < 1024;
// For scalar/coopmat1 FA, we can use the "large" size to accommodate qga.
// For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
uint32_t max_gqa;
@ -8143,7 +8184,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
case FA_SCALAR:
case FA_COOPMAT1:
// We may switch from coopmat1 to scalar, so use the scalar limit for both
max_gqa = get_fa_scalar_num_large_rows(HSK, HSV);
max_gqa = get_fa_scalar_num_large_rows(HSK, HSV, small_cache);
break;
case FA_COOPMAT2:
max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
@ -8177,7 +8218,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
// with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
if (path == FA_SCALAR &&
!ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) {
!ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, small_cache)) {
small_rows = true;
}
@ -8193,7 +8234,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
v_stride /= 4;
}
uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows);
uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows, small_cache);
bool aligned = (KV % alignment) == 0 &&
// the "aligned" shader variant will forcibly align strides, for performance
(q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
@ -8205,7 +8246,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, path, aligned, f32acc);
vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc);
vk_pipeline pipeline = nullptr;
@ -8430,7 +8471,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return nullptr;
case GGML_OP_UPSCALE:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(dst, 0) & 0xFF);
uint32_t mode = (ggml_get_op_params_i32(dst, 0) & (0xFF | GGML_SCALE_FLAG_ANTIALIAS));
switch (mode) {
case GGML_SCALE_MODE_NEAREST:
return ctx->device->pipeline_upscale_nearest_f32;
@ -8438,6 +8479,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_upscale_bilinear_f32;
case GGML_SCALE_MODE_BICUBIC:
return ctx->device->pipeline_upscale_bicubic_f32;
case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS:
return ctx->device->pipeline_upscale_bilinear_antialias_f32;
default:
return nullptr;
}
@ -9088,10 +9131,20 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 };
} break;
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
break;
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
{
uint32_t nrows = (uint32_t)ggml_nrows(src0);
uint32_t z = 1;
if (nrows > ctx->device->properties.limits.maxComputeWorkGroupCount[0]) {
z = CEIL_DIV(nrows, 32768);
nrows = 32768;
}
elements = { nrows, (uint32_t)ne00, z };
} break;
case GGML_OP_GET_ROWS:
elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) };
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
@ -10019,7 +10072,7 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *
uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type);
vk_op_rope_push_constants rope {
(uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
(uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
has_ff, (uint32_t)src0->ne[2], nb01, nb02,
{ sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
@ -13716,6 +13769,7 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
}
static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
VK_LOG_DEBUG("ggml_backend_vk_event_record(backend=" << backend << ", event=" << event << ")");
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
vk_event *vkev = (vk_event *)event->context;
@ -13745,6 +13799,7 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev
}
static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
VK_LOG_DEBUG("ggml_backend_vk_event_wait(backend=" << backend << ", event=" << event << ")");
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
vk_event *vkev = (vk_event *)event->context;
@ -13760,6 +13815,8 @@ static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_even
}
ggml_vk_wait_events(transfer_ctx, {vkev->event});
ggml_vk_ctx_end(transfer_ctx);
ctx->transfer_ctx.reset();
}
// TODO: enable async and synchronize
@ -14324,7 +14381,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
}
return true;
case GGML_OP_UPSCALE:
return op->src[0]->type == GGML_TYPE_F32 && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
if (op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS) {
if ((op->op_params[0] & 0xFF) != GGML_SCALE_MODE_BILINEAR) {
return false;
}
}
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ACC:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_CONCAT:
@ -14519,6 +14581,7 @@ static void ggml_backend_vk_device_event_free(ggml_backend_dev_t dev, ggml_backe
}
static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
VK_LOG_DEBUG("ggml_backend_vk_device_event_synchronize(backend=" << dev << ", event=" << event << ")");
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
auto device = ggml_vk_get_device(ctx->device);
vk_event *vkev = (vk_event *)event->context;

View File

@ -0,0 +1,51 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#include "types.glsl"
layout (push_constant) uniform parameter
{
uint32_t ne00;
uint32_t ne01;
uint32_t nb00;
uint32_t nb01;
uint32_t a_offset;
} p;
#define BLOCK_SIZE 256
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {uint data_a[];};
layout (binding = 1) writeonly buffer D {uint data_d[];};
shared uint vals[BLOCK_SIZE];
void main() {
const uint expert_id = gl_WorkGroupID.x;
const uint num_elements = p.ne00 * p.ne01;
const uint tid = gl_LocalInvocationID.x;
uint count = 0;
for (uint idx = tid; idx < num_elements; idx += BLOCK_SIZE) {
const uint i01 = idx / p.ne00;
const uint i00 = idx % p.ne00;
const uint a = data_a[p.a_offset + i01 * p.nb01 + i00 * p.nb00];
count += uint(a == expert_id);
}
vals[tid] = count;
barrier();
[[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
vals[tid] += vals[tid + s];
}
barrier();
}
if (tid == 0) {
data_d[expert_id] = vals[0];
}
}

View File

@ -401,13 +401,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3;
const uint qshift = (iqs & 16) >> 2;
u8vec4 qs = u8vec4(
data_a[a_offset + ib].qs[iq + 0],
data_a[a_offset + ib].qs[iq + 1],
data_a[a_offset + ib].qs[iq + 2],
data_a[a_offset + ib].qs[iq + 3]
);
qs = (qs >> qshift) & uint8_t(0xF);
const u8vec4 qs = unpack8((data_a_packed32[a_offset + ib].qs[iq/4] >> qshift) & 0x0F0F0F0F);
const float dl = float(int(sl | (sh << 4)) - 32);
return dl * vec4(

View File

@ -68,6 +68,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};
layout (binding = 4) readonly buffer Counts {int data_expert_count[];};
#endif
layout (push_constant) uniform parameter
@ -135,13 +136,19 @@ shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#include "mul_mm_funcs.glsl"
void main() {
const uint ic = gl_WorkGroupID.y;
#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.z;
if (ic * BN >= data_expert_count[expert_idx]) {
return;
}
#endif
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.z;
#else
#ifndef MUL_MAT_ID
const uint batch_idx = gl_GlobalInvocationID.z;
const uint i13 = batch_idx / p.ne12;
@ -156,7 +163,6 @@ void main() {
const uint blocks_m = (p.M + BM - 1) / BM;
const uint ir = gl_WorkGroupID.x % blocks_m;
const uint ik = gl_WorkGroupID.x / blocks_m;
const uint ic = gl_WorkGroupID.y;
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
const uint WSUBM = WM / WMITER;

View File

@ -92,6 +92,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};
layout (binding = 4) readonly buffer Counts {int data_expert_count[];};
shared u16vec4 row_ids[BN];
@ -107,11 +108,7 @@ B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const i
{
const uint row_i = blockCoords[0];
if (row_i >= _ne1) {
return B_TYPE(0.0);
}
const u16vec4 row_idx = row_ids[row_i & (BN - 1)];
const u16vec4 row_idx = row_ids[row_i];
B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]];
return ret;
@ -138,6 +135,8 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
uint ids[16];
uint iter = 0;
uint expert_count = data_expert_count[expert_idx];
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
// prefetch up to 16 elements
if (iter == 0) {
@ -185,7 +184,7 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
}
_ne1 += total;
iter &= 15;
if (_ne1 >= (ic + 1) * BN) {
if (_ne1 >= (ic + 1) * BN || _ne1 == expert_count) {
break;
}
}
@ -194,15 +193,28 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
#endif
void main() {
const uint tid = gl_LocalInvocationIndex;
const uint ic = gl_WorkGroupID.y;
#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.z;
if (ic * BN >= data_expert_count[expert_idx]) {
return;
}
// initialize to row 0 so we don't need to bounds check
if (tid < BN) {
row_ids[tid] = u16vec4(0);
}
#if !defined(NEEDS_INIT_IQ_SHMEM)
barrier();
#endif
#endif
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
const uint tid = gl_LocalInvocationIndex;
#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.z;
#else
#ifndef MUL_MAT_ID
const uint batch_idx = gl_GlobalInvocationID.z;
const uint i13 = batch_idx / p.ne12;
@ -217,7 +229,6 @@ void main() {
const uint blocks_m = (p.M + BM - 1) / BM;
const uint ir = gl_WorkGroupID.x % blocks_m;
const uint ik = gl_WorkGroupID.x / blocks_m;
const uint ic = gl_WorkGroupID.y;
#ifdef MUL_MAT_ID
if (bitCount(p.nei0) == 1) {
@ -482,7 +493,7 @@ void main() {
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
sum = coopMatMulAdd(mat_a, mat_b, sum);
} else {
@ -490,7 +501,7 @@ void main() {
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
sum = coopMatMulAdd(mat_a, mat_b, sum);
}
@ -526,7 +537,7 @@ void main() {
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
sum = coopMatMulAdd(mat_a, mat_b, sum);
} else {
@ -534,7 +545,7 @@ void main() {
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
sum = coopMatMulAdd(mat_a, mat_b, sum);
}
@ -571,7 +582,7 @@ void main() {
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
#else
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
#endif
@ -583,7 +594,7 @@ void main() {
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
#else
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
#endif

View File

@ -159,14 +159,16 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint is = iqs / 8; // 0..15
const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3
const uint qsshift = halfsplit * 2; // 0,2,4,6
const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
const int8_t us = int8_t(((data_a[ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF)
| (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4));
const float dl = float(data_a[ib].d) * float(us - 32);
buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)),
dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
const vec2 qs = vec2(unpack8((uint(data_a_packed16[ib].qs[qsi / 2]) >> qsshift) & 0x0303).xy);
const vec2 hm = vec2(unpack8(((uint(data_a_packed16[ib].hmask[hmi / 2]) >> (4 * n + halfsplit)) & 0x0101 ^ 0x0101) << 2).xy);
buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * (qs.x - hm.x),
dl * (qs.y - hm.y));
#elif defined(DATA_A_Q4_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@ -198,8 +200,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const float d = loadd.x * sc;
const float m = -loadd.y * mbyte;
buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m),
fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
const vec2 q = vec2(unpack8((uint(data_a_packed16[ib].qs[qsi / 2]) >> (b * 4)) & 0x0F0F).xy);
buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m),
fma(d, q.y, m));
#elif defined(DATA_A_Q5_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@ -213,8 +217,6 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
const uint qhi = (iqs % 16) * 2; // 0,2,4..30
const uint8_t hm = uint8_t(1 << (iqs / 16));
const vec2 loadd = vec2(data_a[ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
@ -234,8 +236,12 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const float d = loadd.x * sc;
const float m = -loadd.y * mbyte;
buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m),
fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
const uint qs = (uint(data_a_packed16[ib].qs[qsi / 2]) >> (b * 4)) & 0x0F0F;
const uint qh = ((uint(data_a_packed16[ib].qh[qhi / 2]) >> (iqs / 16)) & 0x0101) << 4;
const vec2 q = vec2(unpack8(qs | qh).xy);
buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m),
fma(d, q.y, m));
#elif defined(DATA_A_Q6_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@ -394,11 +400,9 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const float d = float(data_a[ib].d);
const uint qs = data_a[ib].qs[iqs];
const uint signs = pack32(u8vec4(
data_a[ib].qs[is+0],
data_a[ib].qs[is+1],
data_a[ib].qs[is+2],
data_a[ib].qs[is+3]
const uint signs = pack32(u16vec2(
data_a_packed16[ib].qs[is/2],
data_a_packed16[ib].qs[is/2+1]
));
const float db = d * 0.5 * (0.5 + (signs >> 28));
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
@ -443,8 +447,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3;
const uint qshift = (idx & 8) >> 1;
u8vec2 qs = u8vec2(data_a[ib].qs[iq], data_a[ib].qs[iq + 1]);
qs = (qs >> qshift) & uint8_t(0xF);
u8vec2 qs = unpack8((uint(data_a_packed16[ib].qs[iq/2]) >> qshift) & 0x0F0F).xy;
const float d = float(data_a[ib].d);
const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);

View File

@ -13,6 +13,8 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
uint ids[16];
uint iter = 0;
uint expert_count = data_expert_count[expert_idx];
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
// prefetch up to 16 elements
if (iter == 0) {
@ -60,7 +62,7 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
}
_ne1 += total;
iter &= 15;
if (_ne1 >= (ic + 1) * BN) {
if (_ne1 >= (ic + 1) * BN || _ne1 == expert_count) {
break;
}
}

View File

@ -35,6 +35,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};
layout (binding = 4) readonly buffer Counts {int data_expert_count[];};
#endif
layout (push_constant) uniform parameter
@ -104,13 +105,19 @@ block_b_cache cache_b;
#include "mul_mmq_funcs.glsl"
void main() {
const uint ic = gl_WorkGroupID.y;
#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.z;
if (ic * BN >= data_expert_count[expert_idx]) {
return;
}
#endif
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.z;
#else
#ifndef MUL_MAT_ID
const uint batch_idx = gl_GlobalInvocationID.z;
const uint i13 = batch_idx / p.ne12;
@ -125,7 +132,6 @@ void main() {
const uint blocks_m = (p.M + BM - 1) / BM;
const uint ir = gl_WorkGroupID.x % blocks_m;
const uint ik = gl_WorkGroupID.x / blocks_m;
const uint ic = gl_WorkGroupID.y;
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
const uint WSUBM = WM / WMITER;

View File

@ -6,6 +6,9 @@
void main() {
const uint i0 = 2*gl_GlobalInvocationID.y;
// i1 is actually i2*nb2+i1, but the rows are contiguous
const uint i1 = gl_GlobalInvocationID.x;
const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
if (i1 >= pc.nrows) {
return;
}
rope_multi(i0, i1, pc);
}

View File

@ -6,6 +6,9 @@
void main() {
const uint i0 = 2*gl_GlobalInvocationID.y;
// i1 is actually i2*nb2+i1, but the rows are contiguous
const uint i1 = gl_GlobalInvocationID.x;
const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
if (i1 >= pc.nrows) {
return;
}
rope_neox(i0, i1, pc);
}

View File

@ -6,6 +6,9 @@
void main() {
const uint i0 = 2*gl_GlobalInvocationID.y;
// i1 is actually i2*nb2+i1, but the rows are contiguous
const uint i1 = gl_GlobalInvocationID.x;
const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
if (i1 >= pc.nrows) {
return;
}
rope_norm(i0, i1, pc);
}

View File

@ -6,6 +6,7 @@
struct rope_params {
uint rope_mode;
uint ncols;
uint nrows;
uint n_dims;
float freq_scale;
uint p_delta_rows;

View File

@ -6,6 +6,9 @@
void main() {
const uint i0 = 2*gl_GlobalInvocationID.y;
// i1 is actually i2*nb2+i1, but the rows are contiguous
const uint i1 = gl_GlobalInvocationID.x;
const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
if (i1 >= pc.nrows) {
return;
}
rope_vision(i0, i1, pc);
}

View File

@ -172,16 +172,12 @@ struct block_q8_0
float16_t d;
int8_t qs[32];
};
struct block_q8_0_packed16
{
float16_t d;
int16_t qs[32/2];
};
struct block_q8_0_packed32
{
float16_t d;
int32_t qs[32/4];
};
#if defined(DATA_A_Q8_0)
#define QUANT_K QUANT_K_Q8_0
@ -189,7 +185,6 @@ struct block_q8_0_packed32
#define QUANT_AUXF 1
#define A_TYPE block_q8_0
#define A_TYPE_PACKED16 block_q8_0_packed16
#define A_TYPE_PACKED32 block_q8_0_packed32
#define DATA_A_QUANT_LEGACY
#endif
@ -201,11 +196,13 @@ struct block_q8_1
f16vec2 ds;
int8_t qs[32];
};
struct block_q8_1_packed16
{
f16vec2 ds;
int16_t qs[16];
};
struct block_q8_1_packed32
{
f16vec2 ds;
@ -218,6 +215,7 @@ struct block_q8_1_x4
f16vec2 ds[4];
int32_t qs[32];
};
struct block_q8_1_x4_packed128
{
f16vec2 ds[4];
@ -1346,10 +1344,28 @@ struct block_iq4_xs
uint8_t qs[QUANT_K_IQ4_XS/2];
};
struct block_iq4_xs_packed16
{
float16_t d;
uint16_t scales_h;
uint16_t scales_l[QUANT_K_IQ4_XS/128];
uint16_t qs[QUANT_K_IQ4_XS/4];
};
struct block_iq4_xs_packed32
{
float16_t d;
uint16_t scales_h;
uint32_t scales_l;
uint32_t qs[QUANT_K_IQ4_XS/8];
};
#if defined(DATA_A_IQ4_XS)
#define QUANT_K QUANT_K_IQ4_XS
#define QUANT_R QUANT_R_IQ4_XS
#define A_TYPE block_iq4_xs
#define A_TYPE_PACKED16 block_iq4_xs_packed16
#define A_TYPE_PACKED32 block_iq4_xs_packed32
#endif
#define QUANT_K_IQ4_NL 32

View File

@ -21,6 +21,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
#define NEAREST 0
#define BILINEAR 1
#define BICUBIC 2
#define BILINEAR_ANTIALIAS 513
layout (constant_id = 0) const uint scale_mode = 0;
@ -62,6 +63,56 @@ float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) {
return fetch_bilinear(c0, c1, d, i12, i13);
}
float triangle_filter(float x) {
return max(1.0f - abs(x), 0.0f);
}
float interpolate_bilinear_antialias(uint i10, uint i11, uint i12, uint i13) {
const float support1 = max(1.0f, 1.0f / p.sf1);
const float invscale1 = 1.0f / support1;
const float support0 = max(1.0f, 1.0f / p.sf0);
const float invscale0 = 1.0f / support0;
const uint i02 = uint(i12 / p.sf2);
const uint i03 = uint(i13 / p.sf3);
const float y = (float(i11) + p.pixel_offset) / p.sf1;
const float x = (float(i10) + p.pixel_offset) / p.sf0;
// the range of source pixels that contribute
const int x_min = max(int(x - support0 + p.pixel_offset), 0);
const int x_max = min(int(x + support0 + p.pixel_offset), int(p.ne00));
const int y_min = max(int(y - support1 + p.pixel_offset), 0);
const int y_max = min(int(y + support1 + p.pixel_offset), int(p.ne01));
// bilinear filter with antialiasing
float val = 0.0f;
float total_weight = 0.0f;
for (int sy = y_min; sy < y_max; sy++) {
const float weight_y = triangle_filter((sy - y + p.pixel_offset) * invscale1);
for (int sx = x_min; sx < x_max; sx++) {
const float weight_x = triangle_filter((sx - x + p.pixel_offset) * invscale0);
const float weight = weight_x * weight_y;
if (weight <= 0.0f) {
continue;
}
const float pixel = data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + sy * p.nb01 + sx * p.nb00];
val += pixel * weight;
total_weight += weight;
}
}
if (total_weight > 0.0f) {
val /= total_weight;
}
return val;
}
// Bicubic interpolation with alpha = -0.75
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
const vec4 bcoeffs1 = vec4( 1.25, -2.25, 0.0, 1.0);
@ -118,6 +169,9 @@ void main() {
case BICUBIC:
result = interpolate_bicubic(i10, i11, i12, i13);
break;
case BILINEAR_ANTIALIAS:
result = interpolate_bilinear_antialias(i10, i11, i12, i13);
break;
}
data_d[p.d_offset + idx] = D_TYPE(result);

View File

@ -945,6 +945,8 @@ void process_shaders() {
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("count_experts", "count_experts.comp", merge_maps(base_dict, {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}}));
for (std::string dim_str : {"", "_3d"}) {
for (bool bda : {false, true}) {
std::string bda_str = bda ? "_bda" : "";

View File

@ -449,6 +449,8 @@ class MODEL_ARCH(IntEnum):
RND1 = auto()
PANGU_EMBED = auto()
MISTRAL3 = auto()
MIMO2 = auto()
LLAMA_EMBED = auto()
class VISION_PROJECTOR_TYPE(IntEnum):
@ -844,6 +846,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.RND1: "rnd1",
MODEL_ARCH.PANGU_EMBED: "pangu-embedded",
MODEL_ARCH.MISTRAL3: "mistral3",
MODEL_ARCH.MIMO2: "mimo2",
MODEL_ARCH.LLAMA_EMBED: "llama-embed",
}
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@ -3196,6 +3200,46 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.MIMO2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_SINKS,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_EXP_PROBS_B,
],
MODEL_ARCH.LLAMA_EMBED: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
# TODO
}

View File

@ -320,6 +320,7 @@ class TensorNameMap:
MODEL_TENSOR.ATTN_SINKS: (
"model.layers.{bid}.self_attn.sinks", # openai-moe
"model.layers.{bid}.self_attn.attention_sink_bias", # mimov2
),
MODEL_TENSOR.ATTN_GATE: (

View File

@ -286,7 +286,7 @@ extern "C" {
// NULL-terminated list of buffer types to use for tensors that match a pattern
const struct llama_model_tensor_buft_override * tensor_buft_overrides;
int32_t n_gpu_layers; // number of layers to store in VRAM
int32_t n_gpu_layers; // number of layers to store in VRAM, a negative value means all layers
enum llama_split_mode split_mode; // how to split the model across multiple GPUs
// the GPU that is used for the entire model when split_mode is LLAMA_SPLIT_MODE_NONE
@ -467,10 +467,17 @@ extern "C" {
// Frees all allocated memory
LLAMA_API void llama_free(struct llama_context * ctx);
enum llama_params_fit_status {
LLAMA_PARAMS_FIT_STATUS_SUCCESS = 0, // found allocations that are projected to fit
LLAMA_PARAMS_FIT_STATUS_FAILURE = 1, // could not find allocations that are projected to fit
LLAMA_PARAMS_FIT_STATUS_ERROR = 2, // a hard error occured, e.g. because no model could be found at the specified path
};
// fits mparams and cparams to free device memory (assumes system memory is unlimited)
// returns true if the parameters could be successfully modified to fit device memory
// this function is NOT thread safe because it modifies the global llama logger state
LLAMA_API bool llama_params_fit(
// - returns true if the parameters could be successfully modified to fit device memory
// - this function is NOT thread safe because it modifies the global llama logger state
// - only parameters that have the same value as in llama_default_model_params are modified
LLAMA_API enum llama_params_fit_status llama_params_fit(
const char * path_model,
struct llama_model_params * mparams,
struct llama_context_params * cparams,

View File

@ -88,6 +88,7 @@ add_library(llama
models/llama-iswa.cpp
models/llama.cpp
models/mamba.cpp
models/mimo2-iswa.cpp
models/minicpm3.cpp
models/minimax-m2.cpp
models/modern-bert.cpp

View File

@ -115,6 +115,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_RND1, "rnd1" },
{ LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
{ LLM_ARCH_MISTRAL3, "mistral3" },
{ LLM_ARCH_MIMO2, "mimo2" },
{ LLM_ARCH_LLAMA_EMBED, "llama-embed" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@ -500,6 +502,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
case LLM_ARCH_LLAMA:
case LLM_ARCH_DECI:
case LLM_ARCH_MISTRAL3:
case LLM_ARCH_LLAMA_EMBED:
return {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_OUTPUT_NORM,
@ -2188,6 +2191,27 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_VISEXP_FFN_DOWN,
LLM_TENSOR_VISEXP_FFN_UP,
};
case LLM_ARCH_MIMO2:
return {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_OUTPUT,
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_Q,
LLM_TENSOR_ATTN_K,
LLM_TENSOR_ATTN_V,
LLM_TENSOR_ATTN_SINKS,
LLM_TENSOR_ATTN_OUT,
LLM_TENSOR_FFN_NORM,
LLM_TENSOR_FFN_GATE,
LLM_TENSOR_FFN_DOWN,
LLM_TENSOR_FFN_UP,
LLM_TENSOR_FFN_GATE_INP,
LLM_TENSOR_FFN_GATE_EXPS,
LLM_TENSOR_FFN_DOWN_EXPS,
LLM_TENSOR_FFN_UP_EXPS,
LLM_TENSOR_FFN_EXP_PROBS_B,
};
case LLM_ARCH_GPTJ:
case LLM_ARCH_UNKNOWN:
return {

View File

@ -119,6 +119,8 @@ enum llm_arch {
LLM_ARCH_RND1,
LLM_ARCH_PANGU_EMBED,
LLM_ARCH_MISTRAL3,
LLM_ARCH_MIMO2,
LLM_ARCH_LLAMA_EMBED,
LLM_ARCH_UNKNOWN,
};

View File

@ -294,8 +294,8 @@ llama_context::llama_context(
// enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
bool pipeline_parallel =
model.n_devices() > 1 &&
model.params.n_gpu_layers > (int) model.hparams.n_layer &&
model.params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
model.n_gpu_layers() > model.hparams.n_layer &&
model.split_mode() == LLAMA_SPLIT_MODE_LAYER &&
cparams.offload_kqv &&
!model.has_tensor_overrides();
@ -1570,7 +1570,7 @@ llm_graph_cb llama_context::graph_get_cb() const {
// norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
// FIXME: fix in ggml_backend_sched
const bool full_offload = model.params.n_gpu_layers > (int) model.hparams.n_layer;
const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer;
if (ubatch.n_tokens < 32 || full_offload) {
if (il != -1 && strcmp(name, "norm") == 0) {
const auto & dev_layer = model.dev_layer(il);

View File

@ -123,10 +123,11 @@ struct llama_hparams {
llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
// the size of the sliding window (0 - no SWA)
uint32_t n_swa = 0;
// if swa_layers[il] == true, then layer il is SWA
// if swa_layers[il] == false, then layer il is dense (i.e. non-SWA)
// if swa_layers[il] == 1, then layer il is SWA
// if swa_layers[il] == 0, then layer il is dense (i.e. non-SWA)
// by default, all layers are dense
std::array<bool, LLAMA_MAX_LAYERS> swa_layers;
// note: using uint32_t type for compatibility reason
std::array<uint32_t, LLAMA_MAX_LAYERS> swa_layers;
// for State Space Models
uint32_t ssm_d_conv = 0;

View File

@ -130,6 +130,7 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_230B_A10B: return "230B.A10B";
case LLM_TYPE_235B_A22B: return "235B.A22B";
case LLM_TYPE_300B_A47B: return "300B.A47B";
case LLM_TYPE_310B_A15B: return "310B.A15B";
case LLM_TYPE_355B_A32B: return "355B.A32B";
case LLM_TYPE_E2B: return "E2B";
case LLM_TYPE_E4B: return "E4B";
@ -606,7 +607,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON) {
if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON || arch == LLM_ARCH_LLAMA_EMBED) {
if (hparams.n_rot != hparams.n_embd_head_k) {
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
}
@ -630,6 +631,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
// arch-specific KVs
switch (arch) {
case LLM_ARCH_LLAMA:
case LLM_ARCH_LLAMA_EMBED:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@ -2338,6 +2340,22 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_MIMO2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa);
ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer);
switch (hparams.n_layer) {
case 48: type = LLM_TYPE_310B_A15B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
default: throw std::runtime_error("unsupported model architecture");
}
@ -2360,11 +2378,11 @@ void llama_model::load_vocab(llama_model_loader & ml) {
bool llama_model::load_tensors(llama_model_loader & ml) {
const auto & split_mode = params.split_mode;
const auto & n_gpu_layers = params.n_gpu_layers;
const auto & use_mlock = params.use_mlock;
const auto & tensor_split = params.tensor_split;
const int n_layer = hparams.n_layer;
const int n_layer = hparams.n_layer;
const int n_gpu_layers = this->n_gpu_layers();
const bool use_mmap_buffer = true;
@ -2652,6 +2670,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE:
case LLM_ARCH_MISTRAL3:
case LLM_ARCH_LLAMA_EMBED:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@ -6646,6 +6665,44 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { hparams.n_ff_shexp, n_embd }, 0);
}
} break;
case LLM_ARCH_MIMO2:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i);
uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i);
uint32_t n_head = hparams.n_head(i);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_v * n_head, n_embd }, 0);
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.attn_sinks = create_tensor(tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, TENSOR_NOT_REQUIRED);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
// non-MoE branch
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
// MoE branch
int64_t n_ff_exp = hparams.n_ff_exp;
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED);
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED);
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
}
} break;
default:
throw std::runtime_error("unknown architecture");
}
@ -6827,6 +6884,14 @@ size_t llama_model::n_devices() const {
return devices.size();
}
uint32_t llama_model::n_gpu_layers() const {
return params.n_gpu_layers >= 0 ? params.n_gpu_layers : hparams.n_layer + 1;
}
llama_split_mode llama_model::split_mode() const {
return params.split_mode;
}
std::map<ggml_backend_buffer_type_t, size_t> llama_model::memory_breakdown() const {
std::map<ggml_backend_buffer_type_t, size_t> ret;
for (const auto & [ctx, bufs] : pimpl->ctxs_bufs) {
@ -7269,16 +7334,20 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
switch (arch) {
case LLM_ARCH_LLAMA:
{
llm = std::make_unique<llm_build_llama>(*this, params);
llm = std::make_unique<llm_build_llama<false>>(*this, params);
} break;
case LLM_ARCH_LLAMA4:
{
if (hparams.swa_type == LLAMA_SWA_TYPE_NONE) {
llm = std::make_unique<llm_build_llama>(*this, params);
llm = std::make_unique<llm_build_llama<false>>(*this, params);
} else {
llm = std::make_unique<llm_build_llama_iswa>(*this, params);
}
} break;
case LLM_ARCH_LLAMA_EMBED:
{
llm = std::make_unique<llm_build_llama<true>>(*this, params);
} break;
case LLM_ARCH_DECI:
{
llm = std::make_unique<llm_build_deci>(*this, params);
@ -7704,6 +7773,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique<llm_build_mistral3>(*this, params);
} break;
case LLM_ARCH_MIMO2:
{
llm = std::make_unique<llm_build_mimo2_iswa>(*this, params);
} break;
default:
GGML_ABORT("fatal error");
}
@ -7729,7 +7802,7 @@ llama_model_params llama_model_default_params() {
llama_model_params result = {
/*.devices =*/ nullptr,
/*.tensor_buft_overrides =*/ nullptr,
/*.n_gpu_layers =*/ 999,
/*.n_gpu_layers =*/ -1,
/*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER,
/*.main_gpu =*/ 0,
/*.tensor_split =*/ nullptr,
@ -7874,6 +7947,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_ERNIE4_5:
case LLM_ARCH_ERNIE4_5_MOE:
case LLM_ARCH_MISTRAL3:
case LLM_ARCH_LLAMA_EMBED:
return LLAMA_ROPE_TYPE_NORM;
// the pairs of head values are offset by n_rot/2
@ -7933,6 +8007,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_PANGU_EMBED:
case LLM_ARCH_AFMOE:
case LLM_ARCH_QWEN3NEXT:
case LLM_ARCH_MIMO2:
return LLAMA_ROPE_TYPE_NEOX;
case LLM_ARCH_QWEN2VL:

View File

@ -123,6 +123,7 @@ enum llm_type {
LLM_TYPE_230B_A10B, // Minimax M2
LLM_TYPE_235B_A22B,
LLM_TYPE_300B_A47B, // Ernie MoE big
LLM_TYPE_310B_A15B, // /MiMo-V2-Flash
LLM_TYPE_355B_A32B, // GLM-4.5
LLM_TYPE_E2B,
LLM_TYPE_E4B,
@ -465,8 +466,6 @@ struct llama_model {
struct ggml_tensor * dense_2_out_layers = nullptr;
struct ggml_tensor * dense_3_out_layers = nullptr;
llama_model_params params;
// gguf metadata
std::unordered_map<std::string, std::string> gguf_kv;
@ -497,6 +496,9 @@ struct llama_model {
size_t n_tensors() const;
size_t n_devices() const;
uint32_t n_gpu_layers() const;
llama_split_mode split_mode() const;
std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const;
// total number of parameters in the model
@ -525,6 +527,8 @@ struct llama_model {
ggml_cgraph * build_graph(const llm_graph_params & params) const;
private:
llama_model_params params;
struct impl;
std::unique_ptr<impl> pimpl;
};

View File

@ -140,6 +140,10 @@ enum layer_fraction_t {
};
// this enum is only used in llama_params_fit_impl but needs to be defined outside of it to fix a Windows compilation issue
class llama_params_fit_exception : public std::runtime_error {
using std::runtime_error::runtime_error;
};
static void llama_params_fit_impl(
const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams,
float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides,
@ -181,12 +185,11 @@ static void llama_params_fit_impl(
}
}
int64_t sum_total = 0;
int64_t sum_free = 0;
int64_t sum_projected_free = 0;
int64_t min_projected_free = INT64_MAX;
int64_t sum_projected_used = 0;
int64_t sum_projected_model = 0;
int64_t sum_projected_ctx = 0;
if (nd > 1) {
LLAMA_LOG_INFO("%s: projected memory use with initial parameters [MiB]:\n", __func__);
@ -197,12 +200,11 @@ static void llama_params_fit_impl(
const int64_t projected_used = dmd.mb.total();
const int64_t projected_free = dmd.free - projected_used;
sum_total += dmd.total;
sum_free += dmd.free;
sum_projected_used += projected_used;
sum_projected_free += projected_free;
min_projected_free = std::min(min_projected_free, projected_free);
sum_projected_model += dmd.mb.model;
sum_projected_ctx += dmd.mb.context;
if (nd > 1) {
LLAMA_LOG_INFO("%s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " %s\n",
@ -210,10 +212,9 @@ static void llama_params_fit_impl(
projected_free >= 0 ? "surplus" : "deficit");
}
}
assert(sum_total >= 0 && sum_projected_used >= 0 && sum_projected_ctx >= 0);
assert(sum_projected_used >= sum_projected_ctx);
assert(sum_free >= 0 && sum_projected_used >= 0);
LLAMA_LOG_INFO("%s: projected to use %" PRId64 " MiB of device memory vs. %" PRId64 " MiB of free device memory\n",
__func__, sum_projected_used/MiB, sum_total/MiB);
__func__, sum_projected_used/MiB, sum_free/MiB);
if (min_projected_free >= margin) {
if (nd == 1) {
LLAMA_LOG_INFO("%s: will leave %" PRId64 " >= %" PRId64 " MiB of free device memory, no changes needed\n",
@ -236,9 +237,7 @@ static void llama_params_fit_impl(
__func__, margin/MiB, -global_surplus/MiB);
if (cparams->n_ctx == 0) {
if (hp_nct > n_ctx_min) {
const int64_t bytes_per_ctx = sum_projected_ctx / hp_nct;
int64_t memory_reduction = -global_surplus;
int64_t sum_used_target = sum_free - nd*margin_s;
if (nd > 1) {
// for multiple devices we need to be more conservative in terms of how much context we think can fit:
// - for dense models only whole layers can be assigned to devices
@ -246,24 +245,34 @@ static void llama_params_fit_impl(
// - on average we expect a waste of 0.5 layers/tensors per device
// - use slightly more than the expected average for nd devices to be safe
const int64_t model_per_layer = sum_projected_model / std::min(uint32_t(mparams->n_gpu_layers), hp_ngl);
memory_reduction += (nd + 1) * model_per_layer / (hp_nex == 0 ? 2 : 6);
sum_used_target -= (nd + 1) * model_per_layer / (hp_nex == 0 ? 2 : 6);
}
uint32_t ctx_reduction = std::min(uint32_t((memory_reduction + bytes_per_ctx - 1) / bytes_per_ctx), hp_nct - n_ctx_min);
cparams->n_ctx = hp_nct - ctx_reduction;
cparams->n_ctx = std::max(cparams->n_ctx - cparams->n_ctx % 256, n_ctx_min); // round down context for CUDA backend
int64_t sum_projected_used_min_ctx = 0;
cparams->n_ctx = n_ctx_min;
const dmds_t dmds_min_ctx = llama_get_device_memory_data(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level);
for (const auto & dmd : dmds_min_ctx) {
sum_projected_used_min_ctx += dmd.mb.total();
}
if (sum_used_target > sum_projected_used_min_ctx) {
// linear interpolation between minimum and maximum context size:
cparams->n_ctx += (hp_nct - n_ctx_min) * (sum_used_target - sum_projected_used_min_ctx)
/ (sum_projected_used - sum_projected_used_min_ctx);
cparams->n_ctx = std::max(cparams->n_ctx - cparams->n_ctx % 256, n_ctx_min); // round down context for CUDA backend
ctx_reduction = hp_nct - cparams->n_ctx;
memory_reduction = ctx_reduction * bytes_per_ctx;
global_surplus += memory_reduction;
LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n",
__func__, hp_nct, cparams->n_ctx, memory_reduction/MiB);
if (global_surplus >= 0) {
const int64_t bytes_per_ctx = (sum_projected_used - sum_projected_used_min_ctx) / (hp_nct - n_ctx_min);
const int64_t memory_reduction = (hp_nct - cparams->n_ctx) * bytes_per_ctx;
LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n",
__func__, hp_nct, cparams->n_ctx, memory_reduction/MiB);
if (nd == 1) {
LLAMA_LOG_INFO("%s: entire model can be fit by reducing context\n", __func__);
return;
}
LLAMA_LOG_INFO("%s: entire model should be fit across devices by reducing context\n", __func__);
} else {
const int64_t memory_reduction = sum_projected_used - sum_projected_used_min_ctx;
LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n",
__func__, hp_nct, cparams->n_ctx, memory_reduction/MiB);
}
} else {
LLAMA_LOG_INFO("%s: default model context size is %" PRIu32 " which is <= the min. context size of %" PRIu32 " -> no change\n",
@ -276,28 +285,28 @@ static void llama_params_fit_impl(
}
if (mparams->n_gpu_layers != default_mparams.n_gpu_layers) {
throw std::runtime_error("n_gpu_layers already set by user to " + std::to_string(mparams->n_gpu_layers) + ", abort");
throw llama_params_fit_exception("n_gpu_layers already set by user to " + std::to_string(mparams->n_gpu_layers) + ", abort");
}
if (nd > 1) {
if (!tensor_split) {
throw std::runtime_error("did not provide a buffer to write the tensor_split to, abort");
throw llama_params_fit_exception("did not provide a buffer to write the tensor_split to, abort");
}
if (mparams->tensor_split) {
for (size_t id = 0; id < nd; id++) {
if (mparams->tensor_split[id] != 0.0f) {
throw std::runtime_error("model_params::tensor_split already set by user, abort");
throw llama_params_fit_exception("model_params::tensor_split already set by user, abort");
}
}
}
if (mparams->split_mode == LLAMA_SPLIT_MODE_ROW) {
throw std::runtime_error("changing weight allocation for LLAMA_SPLIT_MODE_ROW not implemented, abort");
throw llama_params_fit_exception("changing weight allocation for LLAMA_SPLIT_MODE_ROW not implemented, abort");
}
}
if (!tensor_buft_overrides) {
throw std::runtime_error("did not provide buffer to set tensor_buft_overrides, abort");
throw llama_params_fit_exception("did not provide buffer to set tensor_buft_overrides, abort");
}
if (mparams->tensor_buft_overrides && (mparams->tensor_buft_overrides->pattern || mparams->tensor_buft_overrides->buft)) {
throw std::runtime_error("model_params::tensor_buft_overrides already set by user, abort");
throw llama_params_fit_exception("model_params::tensor_buft_overrides already set by user, abort");
}
// step 3: iteratively fill the back to front with "dense" layers
@ -380,8 +389,8 @@ static void llama_params_fit_impl(
tensor_buft_overrides[itbo].buft = nullptr;
itbo++;
mparams.tensor_buft_overrides = tensor_buft_overrides;
throw std::runtime_error("llama_params_fit_n_tensor_buft_overrides() == "
+ std::to_string(ntbo) + " is insufficient for model\n");
throw llama_params_fit_exception("llama_max_tensor_buft_overrides() == "
+ std::to_string(ntbo) + " is insufficient for model");
}
tensor_buft_overrides[itbo].pattern = get_overflow_pattern(il, il == il0 ? ngl_per_device[id].overflow_type : LAYER_FRACTION_MOE);
tensor_buft_overrides[itbo].buft = overflow_bufts[id];
@ -503,6 +512,9 @@ static void llama_params_fit_impl(
if (mem_high[id] > targets[id]) {
assert(ngl_per_device_high[id].n_layer > ngl_per_device[id].n_layer);
uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer;
if (hp_nex > 0 && size_t(id) == nd - 1) {
delta--;
}
LLAMA_LOG_DEBUG("%s: start filling device %" PRIu32 ", delta=%" PRIu32 "\n", __func__, id, delta);
while (delta > 1) {
uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]);
@ -638,7 +650,7 @@ static void llama_params_fit_impl(
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_UP;
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP\n", __func__);
std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
if (mem_test[id] < targets[id]) {
if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
ngl_per_device = ngl_per_device_test;
mem = mem_test;
id_dense_start = id_dense_start_test;
@ -648,7 +660,7 @@ static void llama_params_fit_impl(
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_GATE;
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE\n", __func__);
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
if (mem_test[id] < targets[id]) {
if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
ngl_per_device = ngl_per_device_test;
mem = mem_test;
id_dense_start = id_dense_start_test;
@ -659,7 +671,7 @@ static void llama_params_fit_impl(
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_ATTN;
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_ATTN\n", __func__);
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
if (mem_test[id] < targets[id]) {
if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
ngl_per_device = ngl_per_device_test;
mem = mem_test;
id_dense_start = id_dense_start_test;
@ -678,22 +690,25 @@ static void llama_params_fit_impl(
set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams);
}
bool llama_params_fit(
enum llama_params_fit_status llama_params_fit(
const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams,
float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides,
size_t margin_s, uint32_t n_ctx_min, enum ggml_log_level log_level) {
const int64_t t0_us = llama_time_us();
bool ok = true;
llama_params_fit_status status = LLAMA_PARAMS_FIT_STATUS_SUCCESS;
try {
llama_params_fit_impl(path_model, mparams, cparams, tensor_split, tensor_buft_overrides, margin_s, n_ctx_min, log_level);
LLAMA_LOG_INFO("%s: successfully fit params to free device memory\n", __func__);
} catch (const std::runtime_error & e) {
} catch (const llama_params_fit_exception & e) {
LLAMA_LOG_WARN("%s: failed to fit params to free device memory: %s\n", __func__, e.what());
ok = false;
status = LLAMA_PARAMS_FIT_STATUS_FAILURE;
} catch (const std::runtime_error & e) {
LLAMA_LOG_ERROR("%s: encountered an error while trying to fit params to free device memory: %s\n", __func__, e.what());
status = LLAMA_PARAMS_FIT_STATUS_ERROR;
}
const int64_t t1_us = llama_time_us();
LLAMA_LOG_INFO("%s: fitting params to free memory took %.2f seconds\n", __func__, (t1_us - t0_us) * 1e-6);
return ok;
return status;
}
struct llama_sampler_chain_params llama_sampler_chain_default_params() {

View File

@ -1,6 +1,7 @@
#include "models.h"
llm_build_llama::llm_build_llama(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
template <bool embed>
llm_build_llama<embed>::llm_build_llama(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@ -14,7 +15,14 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_graph_para
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv();
using inp_attn_type = std::conditional_t<embed, llm_graph_input_attn_no_cache, llm_graph_input_attn_kv>;
inp_attn_type * inp_attn = nullptr;
if constexpr (embed) {
inp_attn = build_attn_inp_no_cache();
} else {
inp_attn = build_attn_inp_kv();
}
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
@ -145,11 +153,16 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_graph_para
cb(cur, "result_norm", -1);
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
if constexpr (!embed) {
// lm_head
cur = build_lora_mm(model.output, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
cb(cur, "result_output", -1);
res->t_logits = cur;
}
ggml_build_forward_expand(gf, cur);
}
template struct llm_build_llama<false>;
template struct llm_build_llama<true>;

123
src/models/mimo2-iswa.cpp Normal file
View File

@ -0,0 +1,123 @@
#include "models.h"
llm_build_mimo2_iswa::llm_build_mimo2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
ggml_tensor * cur;
ggml_tensor * inpL;
inpL = build_inp_embd(model.tok_embd);
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_iswa();
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
uint32_t n_head_l = hparams.n_head(il);
uint32_t n_head_kv_l = hparams.n_head_kv(il);
const float freq_base_l = model.get_rope_freq_base(cparams, il);
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
cur = inpL;
// self_attention
{
cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
ggml_tensor * sinks = model.layers[il].attn_sinks;
cur = build_attn(inp_attn,
model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, sinks, nullptr, 1.0f/sqrtf(float(n_embd_head_k)), il);
}
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
// feed-forward network
if (model.layers[il].ffn_gate_inp == nullptr) {
// dense branch
cur = build_ffn(cur,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
} else {
// MoE branch
cur = build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, true, false,
0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, il);
cb(cur, "ffn_moe_out", il);
}
cur = ggml_add(ctx0, cur, ffn_inp);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}

View File

@ -303,6 +303,7 @@ struct llm_build_llada_moe : public llm_graph_context {
llm_build_llada_moe(const llama_model & model, const llm_graph_params & params);
};
template <bool embed>
struct llm_build_llama : public llm_graph_context {
llm_build_llama(const llama_model & model, const llm_graph_params & params);
};
@ -315,6 +316,10 @@ struct llm_build_mamba : public llm_graph_context_mamba {
llm_build_mamba(const llama_model & model, const llm_graph_params & params);
};
struct llm_build_mimo2_iswa : public llm_graph_context {
llm_build_mimo2_iswa(const llama_model & model, const llm_graph_params & params);
};
struct llm_build_minicpm3 : public llm_graph_context {
llm_build_minicpm3(const llama_model & model, const llm_graph_params & params);
};

View File

@ -402,12 +402,20 @@ static std::string var_to_str(ggml_op_pool pool) {
}
static std::string var_to_str(ggml_scale_mode mode) {
switch (mode) {
case GGML_SCALE_MODE_NEAREST: return "nearest";
case GGML_SCALE_MODE_BILINEAR: return "bilinear";
case GGML_SCALE_MODE_BICUBIC: return "bicubic";
default: return std::to_string(mode);
std::string str;
switch (mode & 0xFF) {
case GGML_SCALE_MODE_NEAREST: str = "nearest"; break;
case GGML_SCALE_MODE_BILINEAR: str = "bilinear"; break;
case GGML_SCALE_MODE_BICUBIC: str = "bicubic"; break;
default: str = std::to_string(mode); break;
}
if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) {
str += "|align_corners";
}
if (mode & GGML_SCALE_FLAG_ANTIALIAS) {
str += "|antialias";
}
return str;
}
#define VAR_TO_STR(x) (#x "=" + var_to_str(x))
@ -5535,18 +5543,16 @@ struct test_interpolate : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
const std::array<int64_t, 4> ne_tgt;
const uint32_t mode = GGML_SCALE_MODE_NEAREST;
const ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST;
std::string vars() override {
ggml_scale_mode mode = (ggml_scale_mode)(this->mode & 0xFF);
std::string flags = (this->mode & GGML_SCALE_FLAG_ALIGN_CORNERS) ? "align_corners" : "none";
return VARS_TO_STR5(type, ne, ne_tgt, mode, flags);
return VARS_TO_STR4(type, ne, ne_tgt, mode);
}
test_interpolate(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {2, 5, 7, 11},
std::array<int64_t, 4> ne_tgt = {5, 7, 11, 13},
uint32_t mode = GGML_SCALE_MODE_NEAREST)
ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST)
: type(type), ne(ne), ne_tgt(ne_tgt), mode(mode) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
@ -7775,6 +7781,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_rope(type, {128, 40, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 13B
test_cases.emplace_back(new test_rope(type, {128, 52, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 30B
test_cases.emplace_back(new test_rope(type, {128, 64, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 65B
test_cases.emplace_back(new test_rope(type, {16, 16, 8192, 1}, 16, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));
}
if (all) {
@ -7789,6 +7796,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
test_cases.emplace_back(new test_rope(type, { 16, 16, 8192, 1}, 16, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw));
}
if (all) {
@ -7802,6 +7810,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 32, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw));
test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
test_cases.emplace_back(new test_rope(type, {128, 16, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen3vl)
test_cases.emplace_back(new test_rope(type, {16, 16, 8192, 1}, 16, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw));
}
test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
@ -7880,9 +7889,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {5, 7, 11, 13}, {2, 5, 7, 11}, mode));
}
for (ggml_scale_mode mode : {GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC}) {
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, mode | GGML_SCALE_FLAG_ALIGN_CORNERS));
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {1, 4, 3, 2}, {2, 8, 3, 2}, mode | GGML_SCALE_FLAG_ALIGN_CORNERS));
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {4, 1, 3, 2}, {1, 1, 3, 2}, mode | GGML_SCALE_FLAG_ALIGN_CORNERS));
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, (ggml_scale_mode)(mode | GGML_SCALE_FLAG_ALIGN_CORNERS)));
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {1, 4, 3, 2}, {2, 8, 3, 2}, (ggml_scale_mode)(mode | GGML_SCALE_FLAG_ALIGN_CORNERS)));
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {4, 1, 3, 2}, {1, 1, 3, 2}, (ggml_scale_mode)(mode | GGML_SCALE_FLAG_ALIGN_CORNERS)));
}
test_cases.emplace_back(new test_sum());

View File

@ -26,16 +26,16 @@ int main(int argc, char ** argv) {
llama_numa_init(params.numa);
auto mparams = common_model_params_to_llama(params);
auto cparams = common_context_params_to_llama(params);
const bool success = llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
const llama_params_fit_status status = llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx,
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
if (!success) {
if (status != LLAMA_PARAMS_FIT_STATUS_SUCCESS) {
LOG_ERR("%s: failed to fit CLI arguments to free memory, exiting...\n", __func__);
exit(1);
}
LOG_INF("%s: printing fitted CLI arguments to stdout...\n", __func__);
std::this_thread::sleep_for(10ms); // to avoid a race between stderr and stdout
common_log_flush(common_log_main());
printf("-c %" PRIu32 " -ngl %" PRIu32, cparams.n_ctx, mparams.n_gpu_layers);
size_t nd = llama_max_devices();

View File

@ -2,6 +2,11 @@
#include "../clip-graph.h"
/*
* IMPORTANT: The mtmd module does NOT accept pull requests that are fully or predominantly AI-generated.
* We encourage human contributors to ensure the quality and reliability of the codebase.
*/
struct clip_graph_siglip : clip_graph {
clip_graph_siglip(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
ggml_cgraph * build() override;

View File

@ -27,6 +27,9 @@
* - Make sure the C API is aligned with the libllama C API (as in llama.h)
* - Do not include model name (e.g., qwen, gemma) in the API, use generic terms instead
* - Keep the API minimal, do not expose internal details unless necessary
*
* IMPORTANT: The mtmd module does NOT accept pull requests that are fully or predominantly AI-generated.
* We encourage human contributors to ensure the quality and reliability of the codebase.
*/
#ifdef LLAMA_SHARED

View File

@ -1486,6 +1486,7 @@ The precedence rule for preset options is as follows:
We also offer additional options that are exclusive to presets (these aren't treated as command-line arguments):
- `load-on-startup` (boolean): Controls whether the model loads automatically when the server starts
- `stop-timeout` (int, seconds): After requested unload, wait for this many seconds before forcing termination (default: 10)
### Routing requests
@ -1574,8 +1575,7 @@ Payload:
```json
{
"model": "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M",
"extra_args": ["-n", "128", "--top-k", "4"]
"model": "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M"
}
```

View File

@ -1007,8 +1007,10 @@ private:
return ret;
}
void clear_slot(server_slot & slot) const {
GGML_ASSERT(!slot.is_processing());
void clear_slot(server_slot & slot, bool allow_processing = false) const {
if (!allow_processing) {
GGML_ASSERT(!slot.is_processing());
}
SLT_WRN(slot, "clearing slot with %zu tokens\n", slot.prompt.tokens.size());
@ -2313,6 +2315,12 @@ private:
slot.n_prompt_tokens_processed = 0;
slot.prompt.tokens.keep_first(n_past);
// send initial 0% progress update if needed
// this is to signal the client that the request has started processing
if (slot.task->params.stream && slot.task->params.return_progress) {
send_partial_response(slot, {}, true);
}
}
if (!slot.can_split()) {
@ -2330,7 +2338,7 @@ private:
if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
clear_slot(slot);
clear_slot(slot, /*allow_processing=*/true);
// there is no common part left
slot.n_prompt_tokens_cache = 0;

View File

@ -34,6 +34,8 @@
#include <limits.h>
#endif
#define DEFAULT_STOP_TIMEOUT 10 // seconds
#define CMD_ROUTER_TO_CHILD_EXIT "cmd_router_to_child:exit"
#define CMD_CHILD_TO_ROUTER_READY "cmd_child_to_router:ready"
@ -203,13 +205,14 @@ void server_models::load_models() {
// convert presets to server_model_meta and add to mapping
for (const auto & preset : final_presets) {
server_model_meta meta{
/* preset */ preset.second,
/* name */ preset.first,
/* port */ 0,
/* status */ SERVER_MODEL_STATUS_UNLOADED,
/* last_used */ 0,
/* args */ std::vector<std::string>(),
/* exit_code */ 0
/* preset */ preset.second,
/* name */ preset.first,
/* port */ 0,
/* status */ SERVER_MODEL_STATUS_UNLOADED,
/* last_used */ 0,
/* args */ std::vector<std::string>(),
/* exit_code */ 0,
/* stop_timeout */ DEFAULT_STOP_TIMEOUT,
};
add_model(std::move(meta));
}
@ -227,6 +230,20 @@ void server_models::load_models() {
}
}
// handle custom stop-timeout option
for (auto & [name, inst] : mapping) {
std::string val;
if (inst.meta.preset.get_option(COMMON_ARG_PRESET_STOP_TIMEOUT, val)) {
try {
inst.meta.stop_timeout = std::stoi(val);
} catch (...) {
SRV_WRN("invalid stop-timeout value '%s' for model '%s', using default %d seconds\n",
val.c_str(), name.c_str(), DEFAULT_STOP_TIMEOUT);
inst.meta.stop_timeout = DEFAULT_STOP_TIMEOUT;
}
}
}
// load any autoload models
std::vector<std::string> models_to_load;
for (const auto & [name, inst] : mapping) {
@ -362,7 +379,7 @@ void server_models::unload_lru() {
int64_t lru_last_used = ggml_time_ms();
size_t count_active = 0;
{
std::lock_guard<std::mutex> lk(mutex);
std::unique_lock<std::mutex> lk(mutex);
for (const auto & m : mapping) {
if (m.second.meta.is_active()) {
count_active++;
@ -376,6 +393,13 @@ void server_models::unload_lru() {
if (!lru_model_name.empty() && count_active >= (size_t)base_params.models_max) {
SRV_INF("models_max limit reached, removing LRU name=%s\n", lru_model_name.c_str());
unload(lru_model_name);
// wait for unload to complete
{
std::unique_lock<std::mutex> lk(mutex);
cv.wait(lk, [this, &lru_model_name]() {
return mapping[lru_model_name].meta.status == SERVER_MODEL_STATUS_UNLOADED;
});
}
}
}
@ -436,38 +460,83 @@ void server_models::load(const std::string & name) {
// start a thread to manage the child process
// captured variables are guaranteed to be destroyed only after the thread is joined
inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port]() {
// read stdout/stderr and forward to main server log
bool state_received = false; // true if child state received
FILE * p_stdout_stderr = subprocess_stdout(child_proc.get());
if (p_stdout_stderr) {
char buffer[4096];
while (fgets(buffer, sizeof(buffer), p_stdout_stderr) != nullptr) {
LOG("[%5d] %s", port, buffer);
if (!state_received && std::strstr(buffer, CMD_CHILD_TO_ROUTER_READY) != nullptr) {
// child process is ready
this->update_status(name, SERVER_MODEL_STATUS_LOADED);
state_received = true;
inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port, stop_timeout = inst.meta.stop_timeout]() {
FILE * stdin_file = subprocess_stdin(child_proc.get());
FILE * stdout_file = subprocess_stdout(child_proc.get()); // combined stdout/stderr
std::thread log_thread([&]() {
// read stdout/stderr and forward to main server log
// also handle status report from child process
bool state_received = false; // true if child state received
if (stdout_file) {
char buffer[4096];
while (fgets(buffer, sizeof(buffer), stdout_file) != nullptr) {
LOG("[%5d] %s", port, buffer);
if (!state_received && std::strstr(buffer, CMD_CHILD_TO_ROUTER_READY) != nullptr) {
// child process is ready
this->update_status(name, SERVER_MODEL_STATUS_LOADED, 0);
state_received = true;
}
}
} else {
SRV_ERR("failed to get stdout/stderr of child process for name=%s\n", name.c_str());
}
} else {
SRV_ERR("failed to get stdout/stderr of child process for name=%s\n", name.c_str());
}
});
std::thread stopping_thread([&]() {
// thread to monitor stopping signal
auto is_stopping = [this, &name]() {
return this->stopping_models.find(name) != this->stopping_models.end();
};
{
std::unique_lock<std::mutex> lk(this->mutex);
this->cv_stop.wait(lk, is_stopping);
}
SRV_INF("stopping model instance name=%s\n", name.c_str());
// send interrupt to child process
fprintf(stdin_file, "%s\n", CMD_ROUTER_TO_CHILD_EXIT);
fflush(stdin_file);
// wait to stop gracefully or timeout
int64_t start_time = ggml_time_ms();
while (true) {
std::unique_lock<std::mutex> lk(this->mutex);
if (!is_stopping()) {
return; // already stopped
}
int64_t elapsed = ggml_time_ms() - start_time;
if (elapsed >= stop_timeout * 1000) {
// timeout, force kill
SRV_WRN("force-killing model instance name=%s after %d seconds timeout\n", name.c_str(), stop_timeout);
subprocess_terminate(child_proc.get());
return;
}
this->cv_stop.wait_for(lk, std::chrono::seconds(1));
}
});
// we reach here when the child process exits
// note: we cannot join() prior to this point because it will close stdin_file
if (log_thread.joinable()) {
log_thread.join();
}
// stop the timeout monitoring thread
{
std::lock_guard<std::mutex> lk(this->mutex);
stopping_models.erase(name);
cv_stop.notify_all();
}
if (stopping_thread.joinable()) {
stopping_thread.join();
}
// get the exit code
int exit_code = 0;
subprocess_join(child_proc.get(), &exit_code);
subprocess_destroy(child_proc.get());
// update PID and status
{
std::lock_guard<std::mutex> lk(mutex);
auto it = mapping.find(name);
if (it != mapping.end()) {
auto & meta = it->second.meta;
meta.exit_code = exit_code;
meta.status = SERVER_MODEL_STATUS_UNLOADED;
}
cv.notify_all();
}
// update status and exit code
this->update_status(name, SERVER_MODEL_STATUS_UNLOADED, exit_code);
SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code);
});
@ -488,22 +557,14 @@ void server_models::load(const std::string & name) {
cv.notify_all();
}
static void interrupt_subprocess(FILE * stdin_file) {
// because subprocess.h does not provide a way to send SIGINT,
// we will send a command to the child process to exit gracefully
if (stdin_file) {
fprintf(stdin_file, "%s\n", CMD_ROUTER_TO_CHILD_EXIT);
fflush(stdin_file);
}
}
void server_models::unload(const std::string & name) {
std::lock_guard<std::mutex> lk(mutex);
auto it = mapping.find(name);
if (it != mapping.end()) {
if (it->second.meta.is_active()) {
SRV_INF("unloading model instance name=%s\n", name.c_str());
interrupt_subprocess(it->second.stdin_file);
stopping_models.insert(name);
cv_stop.notify_all();
// status change will be handled by the managing thread
} else {
SRV_WRN("model instance name=%s is not loaded\n", name.c_str());
@ -518,7 +579,8 @@ void server_models::unload_all() {
for (auto & [name, inst] : mapping) {
if (inst.meta.is_active()) {
SRV_INF("unloading model instance name=%s\n", name.c_str());
interrupt_subprocess(inst.stdin_file);
stopping_models.insert(name);
cv_stop.notify_all();
// status change will be handled by the managing thread
}
// moving the thread to join list to avoid deadlock
@ -532,16 +594,15 @@ void server_models::unload_all() {
}
}
void server_models::update_status(const std::string & name, server_model_status status) {
// for now, we only allow updating to LOADED status
if (status != SERVER_MODEL_STATUS_LOADED) {
throw std::runtime_error("invalid status value");
}
auto meta = get_meta(name);
if (meta.has_value()) {
meta->status = status;
update_meta(name, meta.value());
void server_models::update_status(const std::string & name, server_model_status status, int exit_code) {
std::unique_lock<std::mutex> lk(mutex);
auto it = mapping.find(name);
if (it != mapping.end()) {
auto & meta = it->second.meta;
meta.status = status;
meta.exit_code = exit_code;
}
cv.notify_all();
}
void server_models::wait_until_loaded(const std::string & name) {
@ -568,6 +629,7 @@ bool server_models::ensure_model_loaded(const std::string & name) {
load(name);
}
// for loading state
SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str());
wait_until_loaded(name);
@ -795,7 +857,7 @@ void server_models_routes::init_routes() {
res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
return res;
}
if (model->status != SERVER_MODEL_STATUS_LOADED) {
if (!model->is_active()) {
res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
return res;
}

View File

@ -9,6 +9,7 @@
#include <condition_variable>
#include <functional>
#include <memory>
#include <set>
/**
* state diagram:
@ -56,6 +57,7 @@ struct server_model_meta {
int64_t last_used = 0; // for LRU unloading
std::vector<std::string> args; // args passed to the model instance, will be populated by render_args()
int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED)
int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown
bool is_active() const {
return status == SERVER_MODEL_STATUS_LOADED || status == SERVER_MODEL_STATUS_LOADING;
@ -83,6 +85,10 @@ private:
std::condition_variable cv;
std::map<std::string, instance_t> mapping;
// for stopping models
std::condition_variable cv_stop;
std::set<std::string> stopping_models;
common_preset_context ctx_preset;
common_params base_params;
@ -119,7 +125,7 @@ public:
void unload_all();
// update the status of a model instance (thread-safe)
void update_status(const std::string & name, server_model_status status);
void update_status(const std::string & name, server_model_status status, int exit_code);
// wait until the model instance is fully loaded (thread-safe)
// return when the model is loaded or failed to load

View File

@ -434,8 +434,8 @@ def test_context_size_exceeded_stream():
@pytest.mark.parametrize(
"n_batch,batch_count,reuse_cache",
[
(64, 3, False),
(64, 1, True),
(64, 4, False),
(64, 2, True),
]
)
def test_return_progress(n_batch, batch_count, reuse_cache):
@ -462,10 +462,18 @@ def test_return_progress(n_batch, batch_count, reuse_cache):
res = make_cmpl_request()
last_progress = None
total_batch_count = 0
for data in res:
cur_progress = data.get("prompt_progress", None)
if cur_progress is None:
continue
if total_batch_count == 0:
# first progress report must have n_cache == n_processed
assert cur_progress["total"] > 0
assert cur_progress["cache"] == cur_progress["processed"]
if reuse_cache:
# when reusing cache, we expect some cached tokens
assert cur_progress["cache"] > 0
if last_progress is not None:
assert cur_progress["total"] == last_progress["total"]
assert cur_progress["cache"] == last_progress["cache"]
@ -473,6 +481,7 @@ def test_return_progress(n_batch, batch_count, reuse_cache):
total_batch_count += 1
last_progress = cur_progress
# last progress should indicate completion (all tokens processed)
assert last_progress is not None
assert last_progress["total"] > 0
assert last_progress["processed"] == last_progress["total"]