metal: generalize the L2_NORM op

Summary of the changes in this patch:
  * ggml-metal-device.cpp: the pipeline name is now built from the type names
    of src0 and of the op result, with a "_4" suffix when src0->ne[0] is a
    multiple of 4. The same condition is stored in res.c4. The two
    GGML_ASSERTs (ne[0] % 4 == 0, ggml_is_contiguous_1) are removed.
  * ggml-metal-device.m: GGML_OP_L2_NORM now falls through to the same support
    check as the cases listed above it (simdgroup reduction available and
    ggml_is_contiguous_rows(src0)).
  * ggml-metal-impl.h: ggml_metal_kargs_l2_norm carries ne00..ne03/nb00..nb03
    and ne0..ne3/nb0..nb3 instead of ne00/ne00_4/nb01.
  * ggml-metal-ops.cpp: asserts contiguous rows, fills the new args, divides
    args.ne00 and args.ne0 by 4 when the pipeline is the c4 variant, and
    dispatches an ne01 x ne02 x ne03 grid of threadgroups (one per row).
  * ggml-metal.metal: kernel_l2_norm_f32 becomes the template
    kernel_l2_norm_impl<T0, T>; source rows are addressed with nb01/nb02/nb03
    and destination rows with nb1/nb2/nb3.

Notes on this copy of the patch:
  * Line breaks were reconstructed so that every hunk matches the line counts
    in its "@@" header. Alignment whitespace inside lines is conventional and
    may differ from the upstream bytes.
  * The template parameter/argument lists on the added kernel_l2_norm lines
    were missing and have been restored from the body (T0, T) and the host
    names (f32_f32 -> float, f32_f32_4 -> float4).
  * NOTE(review): the two kernel_rms_norm_* context lines in the first
    ggml-metal.metal hunk also appear to have lost their template argument
    lists. They are left as found because this patch does not show them.

diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index 949e344cc8..517559d12a 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -1480,13 +1480,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_met
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_L2_NORM);
 
-    GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
-    GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
-
     char base[256];
     char name[256];
 
-    snprintf(base, 256, "kernel_l2_norm_f32");
+    const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
+
+    const char * t0_str = ggml_type_name(op->src[0]->type);
+    const char * t_str  = ggml_type_name(op->type);
+
+    snprintf(base, 256, "kernel_l2_norm_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
     snprintf(name, 256, "%s", base);
 
     ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
@@ -1494,6 +1496,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
         res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
+    res.c4   = is_c4;
     res.smem = 32*sizeof(float);
 
     return res;
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
index 50a2a3e7f7..c714ef3add 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.m
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
@@ -1086,9 +1086,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_MEAN:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_GROUP_NORM:
-            return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_L2_NORM:
-            return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
+            return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_COUNT_EQUAL:
             return has_simdgroup_reduction &&
                 op->src[0]->type == GGML_TYPE_I32 &&
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index 44141f8e3d..952e1be076 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -539,8 +539,21 @@ typedef struct {
 
 typedef struct {
     int32_t  ne00;
-    int32_t  ne00_4;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
     uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
     float    eps;
 } ggml_metal_kargs_l2_norm;
 
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index b159a8e7fd..7db95d1c84 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -2979,39 +2979,59 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
     GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
+    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+
+    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
+
     float eps;
     memcpy(&eps, op->op_params, sizeof(float));
 
-    int nth = 32; // SIMD width
-
     ggml_metal_kargs_l2_norm args = {
-        /*.ne00   =*/ ne00,
-        /*.ne00_4 =*/ ne00/4,
-        /*.nb01   =*/ nb01,
-        /*.eps    =*/ eps,
+        /*.ne00  =*/ ne00,
+        /*.ne01  =*/ ne01,
+        /*.ne02  =*/ ne02,
+        /*.ne03  =*/ ne03,
+        /*.nb00  =*/ nb00,
+        /*.nb01  =*/ nb01,
+        /*.nb02  =*/ nb02,
+        /*.nb03  =*/ nb03,
+        /*.ne0   =*/ ne0,
+        /*.ne1   =*/ ne1,
+        /*.ne2   =*/ ne2,
+        /*.ne3   =*/ ne3,
+        /*.nb0   =*/ nb0,
+        /*.nb1   =*/ nb1,
+        /*.nb2   =*/ nb2,
+        /*.nb3   =*/ nb3,
+        /*.eps   =*/ eps,
     };
 
     auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
 
-    while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+    if (pipeline.c4) {
+        args.ne00 = ne00/4;
+        args.ne0  = ne0/4;
+    }
+
+    int nth = 32; // SIMD width
+
+    while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
         nth *= 2;
     }
 
     nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
-    nth = std::min(nth, ne00/4);
 
     const size_t smem = pipeline.smem;
 
-    const int64_t nrows = ggml_nrows(op->src[0]);
-
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
+    ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
 
     ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
-    ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
+    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
 
     return 1;
 }
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 7d841341a1..a385a50b94 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -2706,26 +2706,32 @@ template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_f
 template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl;
 template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl;
 
-kernel void kernel_l2_norm_f32(
+template <typename T0, typename T>
+kernel void kernel_l2_norm_impl(
         constant ggml_metal_kargs_l2_norm & args,
         device const char * src0,
         device       char * dst,
         threadgroup float * shmem_f32 [[threadgroup(0)]],
-        uint   tgpig[[threadgroup_position_in_grid]],
-        ushort tpitg[[thread_position_in_threadgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort   ntg[[threads_per_threadgroup]]) {
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig.z;
+    const int i02 = tgpig.y;
+    const int i01 = tgpig.x;
+
     if (sgitg == 0) {
         shmem_f32[tiisg] = 0.0f;
     }
 
-    device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
+    device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
+    device       T  * y = (device       T  *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1);
 
     float sumf = 0.0f;
 
     // parallel sum
-    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+    for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
         sumf += dot(x[i00], x[i00]);
     }
     sumf = simd_sum(sumf);
@@ -2743,12 +2749,16 @@ kernel void kernel_l2_norm_f32(
 
     const float scale = 1.0f/sqrt(max(sumf, args.eps));
 
-    device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
-    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+    for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
         y[i00] = x[i00] * scale;
     }
 }
 
+typedef decltype(kernel_l2_norm_impl<float, float>) kernel_l2_norm_t;
+
+template [[host_name("kernel_l2_norm_f32_f32")]]   kernel kernel_l2_norm_t kernel_l2_norm_impl<float, float>;
+template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float4, float4>;
+
 kernel void kernel_group_norm_f32(
         constant ggml_metal_kargs_group_norm & args,
         device const float * src0,