metal : remove broken implementation of GGML_OP_SET

ggml-ci
This commit is contained in:
Georgi Gerganov 2025-09-09 14:29:54 +03:00
parent 7fc2b3d503
commit f288225d42
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 0 additions and 108 deletions

View File

@ -477,8 +477,6 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
GGML_METAL_KERNEL_TYPE_SET_I32,
GGML_METAL_KERNEL_TYPE_SET_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
@ -1411,8 +1409,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat);
@ -2012,16 +2008,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
return false;
};
}
case GGML_OP_SET:
{
switch (op->src[0]->type) {
case GGML_TYPE_F32:
case GGML_TYPE_I32:
return true;
default:
return false;
};
}
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_GET_ROWS:
{
@ -5597,68 +5583,6 @@ static int ggml_metal_encode_node(
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
} break;
case GGML_OP_SET:
{
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
// src0 and dst as viewed during set
const size_t dst_nb0 = ggml_element_size(src0);
const size_t dst_nb1 = ((int32_t *) dst->op_params)[0];
const size_t dst_nb2 = ((int32_t *) dst->op_params)[1];
const size_t dst_nb3 = ((int32_t *) dst->op_params)[2];
const size_t offset = ((int32_t *) dst->op_params)[3];
const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
if (!inplace) {
memcpy(((char *) dst->data), ((char *) src0->data), ggml_nbytes(dst));
}
const int im0 = (ne10 == 0 ? 0 : ne10-1);
const int im1 = (ne11 == 0 ? 0 : ne11-1);
const int im2 = (ne12 == 0 ? 0 : ne12-1);
const int im3 = (ne13 == 0 ? 0 : ne13-1);
GGML_ASSERT(offset + im0*dst_nb0 + im1*dst_nb1 + im2*dst_nb2 + im3*dst_nb3 <= ggml_nbytes(dst));
id<MTLComputePipelineState> pipeline = nil;
switch (src0t) {
case GGML_TYPE_F32:
GGML_ASSERT(nb10 == sizeof(float));
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_F32].pipeline; break;
case GGML_TYPE_I32:
GGML_ASSERT(nb10 == sizeof(int32_t));
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_I32].pipeline; break;
default: GGML_ABORT("fatal error");
}
ggml_metal_kargs_set args = {
/*.ne10 =*/ ne10,
/*.ne11 =*/ ne11,
/*.ne12 =*/ ne12,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.nb1 =*/ dst_nb1,
/*.nb2 =*/ dst_nb2,
/*.nb3 =*/ dst_nb3,
/*.offs =*/ offset,
/*.inplace =*/ inplace,
};
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne10);
[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
[encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_POOL_2D:
{
GGML_ASSERT(ggml_is_contiguous(src0));

View File

@ -5571,38 +5571,6 @@ kernel void kernel_flash_attn_ext_vec_reduce(
#undef DV
}
template<typename T>
kernel void kernel_set(
constant ggml_metal_kargs_set & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i13 = tgpig[2];
const int i12 = tgpig[1];
const int i11 = tgpig[0];
const int64_t n = i13*args.ne12*args.ne11*args.ne10 + i12*args.ne11*args.ne10 + i11*args.ne10;
const int64_t i3 = n / (args.ne12*args.ne11*args.ne10);
const int64_t i2 = (n - i3*args.ne12*args.ne11*args.ne10) / (args.ne11*args.ne10);
const int64_t i1 = (n - i3*args.ne12*args.ne11*args.ne10 - i2*args.ne11*args.ne10) / args.ne10;
device T * dst_data = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + args.offs);
for (int64_t i10 = tpitg.x; i10 < args.ne10; i10 += ntg.x) {
device const T * src = (device T *) (src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10);
dst_data[i10] = (T) src[0];
}
}
typedef decltype(kernel_set<float>) kernel_set_t;
template [[host_name("kernel_set_f32")]] kernel kernel_set_t kernel_set<float>;
template [[host_name("kernel_set_i32")]] kernel kernel_set_t kernel_set<int32_t>;
template<typename T0, typename T1>
kernel void kernel_cpy(
constant ggml_metal_kargs_cpy & args,