metal: use mul_mv_ext for large n on non-simdgroup_mm GPUs

On GPUs without simdgroup_mm (e.g. AMD discrete), MUL_MAT with large n
(like pp512) falls through to the per-column mul_mv kernel, which
dispatches ~1.1M threadgroups vs ~280 for the multi-column mul_mv_ext.

Remove the ne11 <= 8 upper bound for non-simdgroup_mm devices so
mul_mv_ext handles all n values. Default r1ptg to 4 for ne11 > 8
instead of aborting.

Benchmarked on AMD Radeon Pro 5300M (Qwen2.5-1.5B-Q4_K_M):
- pp512: 103.18 -> 127.19 t/s (+23.3%)
- tg128: no regression (64.01 t/s)
- test-backend-ops MUL_MAT: 1009/1009 passed
This commit is contained in:
hung 2026-02-13 12:26:49 -05:00
parent b48e80f677
commit 030b09faa8
1 changed files with 3 additions and 3 deletions

View File

@ -1964,14 +1964,14 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
op->src[0]->type == GGML_TYPE_Q8_0 ||
op->src[0]->type == GGML_TYPE_MXFP4 ||
op->src[0]->type == GGML_TYPE_IQ4_NL ||
false) && (ne11 >= 2 && ne11 <= 8)
false) && (ne11 >= 2 && (ne11 <= 8 || !props_dev->has_simdgroup_mm))
) ||
(
(
op->src[0]->type == GGML_TYPE_Q4_K ||
op->src[0]->type == GGML_TYPE_Q5_K ||
op->src[0]->type == GGML_TYPE_Q6_K ||
false) && (ne11 >= 4 && ne11 <= 8)
false) && (ne11 >= 4 && (ne11 <= 8 || !props_dev->has_simdgroup_mm))
)
)
) {
@ -2013,7 +2013,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
case 5:
r1ptg = 5; break;
default:
GGML_ABORT("unsupported ne11");
r1ptg = 4; break;
};
auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);