metal : extend l2_norm support for non-cont src0 (#19502)

This commit is contained in:
Georgi Gerganov 2026-02-11 14:53:19 +02:00 committed by GitHub
parent ada90bf2ba
commit 9ab072ebbe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 75 additions and 30 deletions

View File

@ -1480,13 +1480,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_met
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) { ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_L2_NORM); assert(op->op == GGML_OP_L2_NORM);
GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
char base[256]; char base[256];
char name[256]; char name[256];
snprintf(base, 256, "kernel_l2_norm_f32"); const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
const char * t0_str = ggml_type_name(op->src[0]->type);
const char * t_str = ggml_type_name(op->type);
snprintf(base, 256, "kernel_l2_norm_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
snprintf(name, 256, "%s", base); snprintf(name, 256, "%s", base);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
@ -1494,6 +1496,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
} }
res.c4 = is_c4;
res.smem = 32*sizeof(float); res.smem = 32*sizeof(float);
return res; return res;

View File

@ -1086,9 +1086,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_MEAN: case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
case GGML_OP_GROUP_NORM: case GGML_OP_GROUP_NORM:
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_L2_NORM: case GGML_OP_L2_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_COUNT_EQUAL: case GGML_OP_COUNT_EQUAL:
return has_simdgroup_reduction && return has_simdgroup_reduction &&
op->src[0]->type == GGML_TYPE_I32 && op->src[0]->type == GGML_TYPE_I32 &&

View File

@ -539,8 +539,21 @@ typedef struct {
typedef struct { typedef struct {
int32_t ne00; int32_t ne00;
int32_t ne00_4; int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01; uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
float eps; float eps;
} ggml_metal_kargs_l2_norm; } ggml_metal_kargs_l2_norm;

View File

@ -2979,39 +2979,59 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
float eps; float eps;
memcpy(&eps, op->op_params, sizeof(float)); memcpy(&eps, op->op_params, sizeof(float));
int nth = 32; // SIMD width
ggml_metal_kargs_l2_norm args = { ggml_metal_kargs_l2_norm args = {
/*.ne00 =*/ ne00, /*.ne00 =*/ ne00,
/*.ne00_4 =*/ ne00/4, /*.ne01 =*/ ne01,
/*.nb01 =*/ nb01, /*.ne02 =*/ ne02,
/*.eps =*/ eps, /*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.eps =*/ eps,
}; };
auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op); auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { if (pipeline.c4) {
args.ne00 = ne00/4;
args.ne0 = ne0/4;
}
int nth = 32; // SIMD width
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2; nth *= 2;
} }
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
nth = std::min(nth, ne00/4);
const size_t smem = pipeline.smem; const size_t smem = pipeline.smem;
const int64_t nrows = ggml_nrows(op->src[0]);
ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1); ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
return 1; return 1;
} }

View File

@ -2706,26 +2706,32 @@ template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_f
template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>; template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>; template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
kernel void kernel_l2_norm_f32( template <typename T0, typename T>
kernel void kernel_l2_norm_impl(
constant ggml_metal_kargs_l2_norm & args, constant ggml_metal_kargs_l2_norm & args,
device const char * src0, device const char * src0,
device char * dst, device char * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]], threadgroup float * shmem_f32 [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
ushort tpitg[[thread_position_in_threadgroup]], ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]], ushort tiisg[[thread_index_in_simdgroup]],
ushort ntg[[threads_per_threadgroup]]) { ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig.z;
const int i02 = tgpig.y;
const int i01 = tgpig.x;
if (sgitg == 0) { if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f; shmem_f32[tiisg] = 0.0f;
} }
device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01); device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
float sumf = 0.0f; float sumf = 0.0f;
// parallel sum // parallel sum
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
sumf += dot(x[i00], x[i00]); sumf += dot(x[i00], x[i00]);
} }
sumf = simd_sum(sumf); sumf = simd_sum(sumf);
@ -2743,12 +2749,16 @@ kernel void kernel_l2_norm_f32(
const float scale = 1.0f/sqrt(max(sumf, args.eps)); const float scale = 1.0f/sqrt(max(sumf, args.eps));
device float4 * y = (device float4 *) dst + tgpig*args.ne00_4; for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
y[i00] = x[i00] * scale; y[i00] = x[i00] * scale;
} }
} }
typedef decltype(kernel_l2_norm_impl<float, float>) kernel_l2_norm_t;
template [[host_name("kernel_l2_norm_f32_f32")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float, float>;
template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float4, float4>;
kernel void kernel_group_norm_f32( kernel void kernel_group_norm_f32(
constant ggml_metal_kargs_group_norm & args, constant ggml_metal_kargs_group_norm & args,
device const float * src0, device const float * src0,