metal : add FLOOR, CEIL, ROUND, TRUNC unary ops (#20930)

Co-authored-by: nryoo <nryoo@nryooui-MacBookPro.local>
This commit is contained in:
nuri 2026-03-24 17:13:07 +09:00 committed by GitHub
parent 342d6125bc
commit 92080b4396
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 28 additions and 0 deletions

View File

@ -246,6 +246,10 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal
case GGML_UNARY_OP_EXP: op_num = OP_UNARY_NUM_EXP; break;
case GGML_UNARY_OP_SOFTPLUS: op_num = OP_UNARY_NUM_SOFTPLUS; break;
case GGML_UNARY_OP_EXPM1: op_num = OP_UNARY_NUM_EXPM1; break;
case GGML_UNARY_OP_FLOOR: op_num = OP_UNARY_NUM_FLOOR; break;
case GGML_UNARY_OP_CEIL: op_num = OP_UNARY_NUM_CEIL; break;
case GGML_UNARY_OP_ROUND: op_num = OP_UNARY_NUM_ROUND; break;
case GGML_UNARY_OP_TRUNC: op_num = OP_UNARY_NUM_TRUNC; break;
default: GGML_ABORT("fatal error");
} break;
default: GGML_ABORT("fatal error");

View File

@ -1039,6 +1039,10 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SOFTPLUS:
case GGML_UNARY_OP_EXPM1:
case GGML_UNARY_OP_FLOOR:
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_ROUND:
case GGML_UNARY_OP_TRUNC:
return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
default:
return false;

View File

@ -120,6 +120,10 @@
#define OP_UNARY_NUM_EXP 114
#define OP_UNARY_NUM_SOFTPLUS 115
#define OP_UNARY_NUM_EXPM1 116
#define OP_UNARY_NUM_FLOOR 117
#define OP_UNARY_NUM_CEIL 118
#define OP_UNARY_NUM_ROUND 119
#define OP_UNARY_NUM_TRUNC 120
#define OP_SUM_ROWS_NUM_SUM_ROWS 10
#define OP_SUM_ROWS_NUM_MEAN 11

View File

@ -1094,6 +1094,22 @@ kernel void kernel_unary_impl(
// TODO: precise implementation
dst_ptr[i0] = (T) (exp(x) - 1);
}
if (FC_OP == OP_UNARY_NUM_FLOOR) {
dst_ptr[i0] = (T) floor(x);
}
if (FC_OP == OP_UNARY_NUM_CEIL) {
dst_ptr[i0] = (T) ceil(x);
}
if (FC_OP == OP_UNARY_NUM_ROUND) {
dst_ptr[i0] = (T) round(x);
}
if (FC_OP == OP_UNARY_NUM_TRUNC) {
dst_ptr[i0] = (T) trunc(x);
}
}
#undef FC_OP