diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index 72ad876d5e..9162342ee9 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -1748,6 +1748,28 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d(ggml_met
     return res;
 }
 
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_3d(ggml_metal_library_t lib, const ggml_tensor * op) {
+    assert(op->op == GGML_OP_CONV_3D);
+
+    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
+    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
+    GGML_ASSERT(op->type == GGML_TYPE_F32);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_conv_3d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+    }
+
+    return res;
+}
+
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_UPSCALE);
 
diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h
index fd2b3ddeb5..de43f81931 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.h
+++ b/ggml/src/ggml-metal/ggml-metal-device.h
@@ -148,6 +148,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d  (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d  (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d            (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_3d            (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale            (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad                (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d     (ggml_metal_library_t lib, const struct ggml_tensor * op);
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
index 82101f4714..14144aab08 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.m
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
@@ -1077,6 +1077,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                 (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
                 op->src[1]->type == GGML_TYPE_F32 &&
                 op->type == GGML_TYPE_F32;
+        case GGML_OP_CONV_3D:
+            return ggml_is_contiguous(op->src[0]) &&
+                   ggml_is_contiguous(op->src[1]) &&
+                   (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
+                   op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // dst must be F32: the pipeline getter GGML_ASSERTs it
         case GGML_OP_SUM:
             return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
         case GGML_OP_TRI:
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index 53437b23cd..ea471090cd 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -643,6 +643,42 @@ typedef struct {
     int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
 } ggml_metal_kargs_im2col;
 
+typedef struct {
+    int32_t  IW;
+    int32_t  IH;
+    int32_t  ID;
+    int32_t  OW;
+    int32_t  OH;
+    int32_t  OD;
+    int32_t  KW;
+    int32_t  KH;
+    int32_t  KD;
+    int32_t  s0;
+    int32_t  s1;
+    int32_t  s2;
+    int32_t  p0;
+    int32_t  p1;
+    int32_t  p2;
+    int32_t  d0;
+    int32_t  d1;
+    int32_t  d2;
+    int32_t  IC;
+    int32_t  N;
+    int32_t  OC;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_conv_3d;
+
 typedef struct{
     int32_t  ne00;
     uint64_t nb01;
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index c0bcad392b..3cda21be43 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -394,6 +394,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx);
             } break;
+        case GGML_OP_CONV_3D:
+            {
+                n_fuse = ggml_metal_op_conv_3d(ctx, idx);
+            } break;
         case GGML_OP_UPSCALE:
             {
                 n_fuse = ggml_metal_op_upscale(ctx, idx);
@@ -3697,6 +3701,77 @@ int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_conv_3d(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    // 1. Extract standard dimensions and byte strides
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    // 2. Extract hyperparams from op_params
+    const int32_t s0 = ((const int32_t *)(op->op_params))[0];
+    const int32_t s1 = ((const int32_t *)(op->op_params))[1];
+    const int32_t s2 = ((const int32_t *)(op->op_params))[2];
+    const int32_t p0 = ((const int32_t *)(op->op_params))[3];
+    const int32_t p1 = ((const int32_t *)(op->op_params))[4];
+    const int32_t p2 = ((const int32_t *)(op->op_params))[5];
+    const int32_t d0 = ((const int32_t *)(op->op_params))[6];
+    const int32_t d1 = ((const int32_t *)(op->op_params))[7];
+    const int32_t d2 = ((const int32_t *)(op->op_params))[8];
+    const int32_t IC = ((const int32_t *)(op->op_params))[9];
+    const int32_t N  = ((const int32_t *)(op->op_params))[10];
+    const int32_t OC = ((const int32_t *)(op->op_params))[11];
+
+    // 3. Build the parameter struct using the macro-generated variables
+    ggml_metal_kargs_conv_3d args = {
+        /*.IW =*/ (int32_t)op->src[1]->ne[0],
+        /*.IH =*/ (int32_t)op->src[1]->ne[1],
+        /*.ID =*/ (int32_t)op->src[1]->ne[2],
+        /*.OW =*/ (int32_t)op->ne[0],
+        /*.OH =*/ (int32_t)op->ne[1],
+        /*.OD =*/ (int32_t)op->ne[2],
+        /*.KW =*/ (int32_t)op->src[0]->ne[0],
+        /*.KH =*/ (int32_t)op->src[0]->ne[1],
+        /*.KD =*/ (int32_t)op->src[0]->ne[2],
+        s0, s1, s2,
+        p0, p1, p2,
+        d0, d1, d2,
+        IC, N, OC,
+        nb00, nb01, nb02, nb03, // Weight strides
+        nb10, nb11, nb12, nb13, // Input strides
+        nb0,  nb1,  nb2,  nb3   // Output strides
+    };
+
+    // 4. Fetch the JIT pipeline
+    auto pipeline = ggml_metal_library_get_pipeline_conv_3d(lib, op);
+
+    // 5. Grid mapping
+    int nth0 = 32; // must match the hard-coded threadgroup width inside kernel_conv_3d
+    int nth1 = 1;
+    int nth2 = 1;
+
+    int64_t spatial_volume = (int64_t)args.OW * args.OH * args.OD; // widen before multiplying to avoid int32 overflow
+
+    int ntg0 = (int)((spatial_volume + nth0 - 1) / nth0);
+    int ntg1 = args.OC;
+    int ntg2 = args.N;
+
+    // 6. Bind and Dispatch via the ggml C wrapper
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 3);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, ntg0, ntg1, ntg2, nth0, nth1, nth2);
+
+    return 1;
+}
+
 int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h
index 019f2fec9e..50e3c5c77a 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.h
+++ b/ggml/src/ggml-metal/ggml-metal-ops.h
@@ -75,6 +75,7 @@ int ggml_metal_op_norm                (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_rope                (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_im2col              (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_conv_2d             (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_conv_3d             (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_conv_transpose_1d   (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_conv_transpose_2d   (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_upscale             (ggml_metal_op_t ctx, int idx);
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index b2328605dd..9c6b1c4f62 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -4883,6 +4883,98 @@ kernel void kernel_upscale_bilinear_f32(
     }
 }
 
+template <typename T>
+kernel void kernel_conv_3d(
+        constant ggml_metal_kargs_conv_3d & args,
+        device const char * src0, // Weights [IC * OC, KD, KH, KW]
+        device const char * src1, // Inputs  [IC * N, ID, IH, IW]
+        device       char * dst,  // Outputs [OC * N, OD, OH, OW]
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]]) {
+
+    // 1. Un-flatten the spatial dimension from Grid X
+    int64_t spatial_idx = tgpig.x * 32 + tpitg.x; // 32 == nth0 chosen by the host-side dispatcher
+
+    if (spatial_idx >= (int64_t)args.OW * args.OH * args.OD) {
+        return; // Thread falls outside the spatial volume
+    }
+
+    int64_t od = spatial_idx / ((int64_t)args.OW * args.OH);
+    int64_t oh = (spatial_idx / args.OW) % args.OH;
+    int64_t ow = spatial_idx % args.OW;
+
+    // 2. Map Y to Channels, Z to Batch
+    int64_t oc        = tgpig.y;
+    int64_t batch_idx = tgpig.z;
+
+    // 3. Calculate anchor coordinates in the Input volume
+    int64_t i_w_base = ow * args.s0 - args.p0;
+    int64_t i_h_base = oh * args.s1 - args.p1;
+    int64_t i_d_base = od * args.s2 - args.p2;
+
+    float sum = 0.0f;
+
+    // 4. Gather Loop (Iterate over Input Channels -> Depth -> Height -> Width)
+    for (int64_t ic = 0; ic < args.IC; ++ic) {
+
+        // ggml packs batch and channel together in the 4th dimension
+        int64_t src_cn_idx = batch_idx * args.IC + ic;
+        int64_t w_cn_idx   = oc * args.IC + ic;
+
+        for (int64_t kz = 0; kz < args.KD; ++kz) {
+            int64_t id = i_d_base + kz * args.d2;
+            if (id < 0 || id >= args.ID) continue; // Boundary check (Padding)
+
+            for (int64_t ky = 0; ky < args.KH; ++ky) {
+                int64_t ih = i_h_base + ky * args.d1;
+                if (ih < 0 || ih >= args.IH) continue;
+
+                for (int64_t kx = 0; kx < args.KW; ++kx) {
+                    int64_t iw = i_w_base + kx * args.d0;
+                    if (iw < 0 || iw >= args.IW) continue;
+
+                    // Convert multi-dimensional coordinates to flat byte offsets
+                    int64_t w_idx = kx*args.nb00 + ky*args.nb01 + kz*args.nb02 + w_cn_idx*args.nb03;
+                    int64_t i_idx = iw*args.nb10 + ih*args.nb11 + id*args.nb12 + src_cn_idx*args.nb13;
+
+                    // Dereference memory and cast weights to f32 if they were f16
+                    float w_val = (float)*(device const T*)((device const char*)src0 + w_idx);
+                    float i_val = *(device const float*)((device const char*)src1 + i_idx);
+
+                    sum += w_val * i_val;
+                }
+            }
+        }
+    }
+
+    // 5. Write the accumulated value out to RAM
+    int64_t dst_cn_idx = batch_idx * args.OC + oc;
+    int64_t d_idx = ow*args.nb0 + oh*args.nb1 + od*args.nb2 + dst_cn_idx*args.nb3;
+
+    *(device float*)(dst + d_idx) = sum;
+}
+
+// Explicit instantiations so the JIT compiler can find them by name
+template [[host_name("kernel_conv_3d_f32_f32")]]
+kernel void kernel_conv_3d<float>(
+        constant ggml_metal_kargs_conv_3d & args,
+        device const char * src0,
+        device const char * src1,
+        device char * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]]);
+
+// Explicit instantiation for f16 weights
+template [[host_name("kernel_conv_3d_f16_f32")]]
+kernel void kernel_conv_3d<half>(
+        constant ggml_metal_kargs_conv_3d & args,
+        device const char * src0,
+        device const char * src1,
+        device char * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]]);
+
+
 static inline float bicubic_weight1(float x) {
     const float a = -0.75f;
     return ((a + 2) * x - (a + 3)) * x * x + 1;