diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 06f3d80459..169c63dd7a 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1717,12 +1717,29 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_met char base[256]; char name[256]; - snprintf(base, 256, "kernel_upscale_%s", ggml_type_name(op->src[0]->type)); - snprintf(name, 256, "%s", base); + const int32_t mode_flags = ggml_get_op_params_i32(op, 0); + const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF); + + const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS); + + if (mode == GGML_SCALE_MODE_BILINEAR) { + snprintf(base, 256, "kernel_upscale_bilinear_%s", ggml_type_name(op->src[0]->type)); + } else if (mode == GGML_SCALE_MODE_BICUBIC) { + snprintf(base, 256, "kernel_upscale_bicubic_%s", ggml_type_name(op->src[0]->type)); + } else { + snprintf(base, 256, "kernel_upscale_nearest_%s", ggml_type_name(op->src[0]->type)); + } + snprintf(name, 256, "%s_aa=%d", base, antialias); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, antialias, FC_UPSCALE + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } return res; diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 4cce414abf..23bd2b2ab7 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1108,7 +1108,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te op->type == GGML_TYPE_F32 && (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32); case GGML_OP_UPSCALE: - return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS); + return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_POOL_1D: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_POOL_2D: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 383e0d6e93..bf51055e36 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -83,6 +83,7 @@ #define FC_UNARY 1200 #define FC_BIN 1300 #define FC_SUM_ROWS 1400 +#define FC_UPSCALE 1500 // op-specific constants #define OP_FLASH_ATTN_EXT_NQPSG 8 @@ -890,6 +891,7 @@ typedef struct { float sf1; float sf2; float sf3; + float poffs; } ggml_metal_kargs_upscale; typedef struct { diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index b3390352ff..524e111662 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -3729,32 +3729,43 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - const float sf0 = (float)ne0/op->src[0]->ne[0]; - const float sf1 = (float)ne1/op->src[0]->ne[1]; - const float sf2 = (float)ne2/op->src[0]->ne[2]; - const float sf3 = (float)ne3/op->src[0]->ne[3]; + float sf0 = (float)ne0/op->src[0]->ne[0]; + float sf1 = (float)ne1/op->src[0]->ne[1]; + float sf2 = (float)ne2/op->src[0]->ne[2]; + float sf3 = (float)ne3/op->src[0]->ne[3]; + + const int32_t mode_flags = ggml_get_op_params_i32(op, 0); + + float poffs = 0.5f; + + if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) { + poffs = 0.0f; + sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0; + sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1; + } ggml_metal_kargs_upscale args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, - /*.nb0 =*/ nb0, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3, - /*.sf0 =*/ sf0, - /*.sf1 =*/ sf1, - /*.sf2 =*/ sf2, - /*.sf3 =*/ sf3 + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.sf0 =*/ sf0, + /*.sf1 =*/ sf1, + /*.sf2 =*/ sf2, + /*.sf3 =*/ sf3, + /*.poffs =*/ poffs, }; auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index a58e641ad8..5cfd69dd86 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4530,7 +4530,9 @@ kernel void kernel_conv_transpose_2d( uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]); -kernel void kernel_upscale_f32( +constant bool FC_upscale_aa [[function_constant(FC_UPSCALE + 0)]]; + +kernel void kernel_upscale_nearest_f32( constant ggml_metal_kargs_upscale & args, device const char * src0, device char * dst, @@ -4556,6 +4558,156 @@ kernel void kernel_upscale_f32( } } +static inline float bilinear_tri(float x) { + return MAX(0.0f, 1.0f - fabs(x)); +} + +kernel void kernel_upscale_bilinear_f32( + constant ggml_metal_kargs_upscale & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3 / args.sf3; + const int64_t i02 = i2 / args.sf2; + + const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs; + const int64_t i01 = MAX(0, MIN(args.ne01 - 1, (int64_t)floor(f01))); + const int64_t i01p = MAX(0, MIN(args.ne01 - 1, i01 + 1)); + const float fd1 = MAX(0.0f, MIN(1.0f, f01 - (float)i01)); + + src0 += i03*args.nb03 + i02*args.nb02; + + device float * dst_ptr = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1); + + if (FC_upscale_aa) { + const float support0 = MAX(1.0f, 1.0f / args.sf0); + const float invscale0 = 1.0f / support0; + const float support1 = MAX(1.0f, 1.0f / args.sf1); + const float invscale1 = 1.0f / support1; + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs; + + int64_t x_min = MAX((int64_t)0, (int64_t)floor(f00 - support0 + args.poffs)); + int64_t x_max = MIN(args.ne00, (int64_t)ceil (f00 + support0 + args.poffs)); + + int64_t y_min = MAX((int64_t)0, (int64_t)floor(f01 - support1 + args.poffs)); + int64_t y_max = MIN(args.ne01, (int64_t)ceil (f01 + support1 + args.poffs)); + + float sum = 0.0f; + float wsum = 0.0f; + + for (int64_t sy = y_min; sy < y_max; ++sy) { + const float wy = MAX(0.0f, 1.0f - fabs((float)sy - f01) * invscale1); + for (int64_t sx = x_min; sx < x_max; ++sx) { + const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0); + const float w = wx * wy; + const device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00); + sum += (*src_ptr) * w; + wsum += w; + } + } + + const float v = (wsum > 0.0f) ? (sum / wsum) : 0.0f; + dst_ptr[i0] = v; + } + } else { + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs; + const int64_t i00 = MAX(0, MIN(args.ne00 - 1, (int64_t)floor(f00))); + const int64_t i00p = MAX(0, MIN(args.ne00 - 1, i00 + 1)); + const float fd0 = MAX(0.0f, MIN(1.0f, f00 - (float)i00)); + + device const float * src00 = (device const float *)(src0 + i01*args.nb01 + i00*args.nb00); + device const float * src10 = (device const float *)(src0 + i01*args.nb01 + i00p*args.nb00); + device const float * src01 = (device const float *)(src0 + i01p*args.nb01 + i00*args.nb00); + device const float * src11 = (device const float *)(src0 + i01p*args.nb01 + i00p*args.nb00); + + const float v = + (*src00) * (1.0f - fd0) * (1.0f - fd1) + + (*src10) * fd0 * (1.0f - fd1) + + (*src01) * (1.0f - fd0) * fd1 + + (*src11) * fd0 * fd1; + + dst_ptr[i0] = v; + } + } +} + +static inline float bicubic_weight1(float x) { + const float a = -0.75f; + return ((a + 2) * x - (a + 3)) * x * x + 1; +} + +static inline float bicubic_weight2(float x) { + const float a = -0.75f; + return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; +} + +kernel void kernel_upscale_bicubic_f32( + constant ggml_metal_kargs_upscale & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3 / args.sf3; + const int64_t i02 = i2 / args.sf2; + + const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs; + const int64_t i01 = (int64_t)floor(f01); + const float fd1 = f01 - (float)i01; + + const float w_y0 = bicubic_weight2(fd1 + 1.0f); + const float w_y1 = bicubic_weight1(fd1); + const float w_y2 = bicubic_weight1(1.0f - fd1); + const float w_y3 = bicubic_weight2(2.0f - fd1); + + const device const char * src_slice = src0 + i03 * args.nb03 + i02 * args.nb02; + + device float * dst_ptr = (device float *)(dst + i3 * args.nb3 + i2 * args.nb2 + i1 * args.nb1); + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs; + const int64_t i00 = (int64_t)floor(f00); + const float fd0 = f00 - (float)i00; + + const float w_x0 = bicubic_weight2(fd0 + 1.0f); + const float w_x1 = bicubic_weight1(fd0); + const float w_x2 = bicubic_weight1(1.0f - fd0); + const float w_x3 = bicubic_weight2(2.0f - fd0); + + float sum = 0.0f; + + for (int dy = -1; dy <= 2; ++dy) { + const int64_t iy = MAX(0, MIN(args.ne01 - 1, i01 + dy)); + const float wy = (dy == -1) ? w_y0 : (dy == 0) ? w_y1 : (dy == 1) ? w_y2 : w_y3; + + for (int dx = -1; dx <= 2; ++dx) { + const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx)); + const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3; + + const device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00); + sum += (*src_ptr) * wx * wy; + } + } + + dst_ptr[i0] = sum; + } +} + kernel void kernel_pad_f32( constant ggml_metal_kargs_pad & args, device const char * src0,