From 54523493b5ca0b53fa5cca212b362723ff8f46fd Mon Sep 17 00:00:00 2001
From: Ilia Ilmer
Date: Tue, 16 Dec 2025 19:10:17 -0500
Subject: [PATCH 1/4] starting on cross entropy: kernels

---
Review notes (ignored by git-am, not part of the commit message):
- Forward kernel: the loss now accumulates log_softmax_i * labels_row[i];
  it previously multiplied by logits_row[i] and never read labels_row.
- Forward kernel: a threadgroup barrier precedes the third reduction so that
  buf is not overwritten while other simdgroups may still be reading `sum`.
- Backward kernel: `row` comes from tgpig (threadgroup index) instead of
  tpitg (thread index), matching the forward kernel.
- This series was restored from a flattened copy: line breaks and template
  argument lists were reconstructed, context-line whitespace may differ from
  upstream (apply with `git am --ignore-whitespace`), and the `index` hashes
  are those of the original, unreviewed series.

 ggml/src/ggml-metal/ggml-metal-impl.h  |  11 ++
 ggml/src/ggml-metal/ggml-metal-ops.cpp |  12 +++
 ggml/src/ggml-metal/ggml-metal-ops.h   |   1 +
 ggml/src/ggml-metal/ggml-metal.metal   | 133 +++++++++++++++++++++++++
 4 files changed, 157 insertions(+)

diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index 8944b07e90..a9967d5779 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -908,6 +908,17 @@ typedef struct {
     int64_t np;
 } ggml_metal_kargs_pool_2d;
 
+typedef struct {
+    int32_t ne00;
+    int32_t args;
+    int32_t k;
+} ggm_metal_kargs_cross_entropy_loss;
+
+typedef struct {
+    int32_t ne00;
+    int32_t args;
+} ggm_metal_kargs_cross_entropy_loss_back;
+
 typedef struct {
     int64_t ne00;
     uint64_t nb01;
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index e99c1763f6..44ec0ddbd7 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -1333,6 +1333,18 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_cross_entropy_loss(ggml_metal_op_t ctx, int idx){
+    const ggml_tensor * src0 = ctx->node(idx)->src[0]; // NOTE: logits
+    const ggml_tensor * src1 = ctx->node(idx)->src[1]; // NOTE: labels
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+
+    return 1;
+}
+
 int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h
index 902b544523..7ed76cbf0b 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.h
+++ b/ggml/src/ggml-metal/ggml-metal-ops.h
@@ -57,6 +57,7 @@ int ggml_metal_op_cumsum (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_cross_entropy_loss(ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx);
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 51bcbae309..1108dd22ad 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -2280,6 +2280,139 @@ template [[host_name("kernel_soft_max_f32")]]   kernel kernel_soft_max_t   kerne
 template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
 template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
 
+template<typename T>
+kernel void kernel_cross_entropy_loss(
+        constant ggml_metal_kargs_cross_entropy & args,
+        device const char * src0,
+        device const char * src1,
+        device const char * src2,
+        device char * dst,
+        threadgroup float * buf [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint tptg[[threads_per_threadgroup]]) {
+
+    const int row = tgpig;
+    device const float * logits_row = logits + row * args.ne00;
+    device const float * labels_row = labels + row * args.ne00;
+
+    float lmax = - INFINITY;
+    for (int i = tpitg; i < args.ne00; i+= tptg){
+        lmax = MAX(lmax, logits_row[i]);
+    }
+    float max_val = simd_max(lmax);
+    if (tptg > N_SIMDWIDTH) {
+        if (sgitg == 0) buf[tiisg] = -INFINITY;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        if (tiisg == 0) buf[sgitg] = max_val;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        max_val = buf[tiisg];
+        max_val = simd_max(max_val);
+    }
+
+    float lsum = 0.0f;
+    for (int i = tpitg; i < args.ne00; i += tptg){
+        const float exp_val = exp(logits_row[i] - max_val);
+        lsum += exp_val;
+        dst_row[i] = exp_val;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    float sum = simd_sum(lsum);
+    if (tptg > N_SIMDWIDTH){
+        if (sgitg == 0) buf[tiisg] = 0.0f;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        if (tiisg == 0) buf[sgitg] = sum;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        sum = buf[tiisg];
+        sum = simd_sum(sum);
+    }
+    const float log_sum = log(sum);
+    float lloss = 0.0f;
+    for (int i = tpitg; i < args.ne00; i += tptg){
+        const float log_softmax_i = logits_row[i] - max_val - log_sum;
+        lloss += log_softmax_i * labels_row[i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    float loss = simd_sum(lloss);
+    if (tptg > N_SIMDWIDTH) {
+        if (sgitg == 0) buf[tiisg] = 0.0f;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        if (tiisg == 0) buf[sgitg] = loss;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        loss = buf[tiisg];
+        loss = simd_sum(loss);
+    }
+    if (tpitg == 0) {
+        dst[i] = -loss / args.nrows;
+    }
+}
+
+
+template<typename T>
+kernel void kernel_cross_entropy_loss_back(
+        constant ggml_metal_kargs_cross_entropy_loss_back & args,
+        device const float * grad,
+        device const float * logits, // src0
+        device const float * labels, // src1
+        device float * dst, // output
+        threadgroup float * buf [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint tptg[[threads_per_threadgroup]]){
+
+    const int row = tgpig;
+
+    device const float * logits_row = logits + row * args.ne00;
+    device const float * labels_row = labels + row * args.ne00;
+    device float * dst_row = dst + row * args.ne00;
+
+    // find max
+
+    float lmax = - INFINITY;
+    for (int i = tpitg; i < args.ne00; i+= tptg){
+        lmax = MAX(lmax, logits_row[i]);
+    }
+    float max_val = simd_max(lmax);
+    if (tptg > N_SIMDWIDTH) {
+        if (sgitg == 0) buf[tiisg] = -INFINITY;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        if (tiisg == 0) buf[sgitg] = max_val;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        max_val = buf[tiisg];
+        max_val = simd_max(max_val);
+    }
+
+    float lsum = 0.0f;
+    for (int i = tpitg; i < args.ne00; i += tptg){
+        const float exp_val = exp(logits_row[i] - max_val);
+        lsum += exp_val;
+        dst_row[i] = exp_val;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    float sum = simd_sum(lsum);
+    if (tptg > N_SIMDWIDTH){
+        if (sgitg == 0) buf[tiisg] = 0.0f;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        if (tiisg == 0) buf[sgitg] = sum;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        sum = buf[tiisg];
+        sum = simd_sum(sum);
+    }
+    const float inv_sum = 1.0f / sum;
+    const float d_by_nrows = grad[0] / args.nrows;
+
+    for (int i = tpitg; i < args.ne00; i += tptg){
+        const float softmax_i = dst_row[i] * inv_sum; // exp(logits - max)/ sum(exp(logits - val))
+        dst_row[i] = (softmax_i - labels_row[i]) * d_by_nrows;
+    }
+
+}
+
+
 // ref: ggml.c:ggml_compute_forward_ssm_conv_f32
 kernel void kernel_ssm_conv_f32_f32(
         constant ggml_metal_kargs_ssm_conv & args,
From d9b5b1741118b8e7c2f3ba650f7eb93ffb304250 Mon Sep 17 00:00:00 2001
From: Ilia Ilmer
Date: Tue, 16 Dec 2025 21:59:58 -0500
Subject: [PATCH 2/4] add kernels, add pipeline header declaration

---
 ggml/src/ggml-metal/ggml-metal-device.h | 1 +
 ggml/src/ggml-metal/ggml-metal-impl.h   | 8 ++++----
 ggml/src/ggml-metal/ggml-metal.metal    | 9 ++++++++-
 3 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h
index 0a8b9211a7..c70d59101d 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.h
+++ b/ggml/src/ggml-metal/ggml-metal-device.h
@@ -116,6 +116,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_bl
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cross_entropy (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index a9967d5779..d8422490c5 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -910,14 +910,14 @@ typedef struct {
 
 typedef struct {
     int32_t ne00;
-    int32_t args;
+    int32_t nrows;
     int32_t k;
-} ggm_metal_kargs_cross_entropy_loss;
+} ggml_metal_kargs_cross_entropy_loss;
 
 typedef struct {
     int32_t ne00;
-    int32_t args;
-} ggm_metal_kargs_cross_entropy_loss_back;
+    int32_t nrows;
+} ggml_metal_kargs_cross_entropy_loss_back;
 
 typedef struct {
     int64_t ne00;
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 1108dd22ad..36074cdafa 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -2285,7 +2285,6 @@ kernel void kernel_cross_entropy_loss(
         constant ggml_metal_kargs_cross_entropy & args,
         device const char * src0,
         device const char * src1,
-        device const char * src2,
         device char * dst,
         threadgroup float * buf [[threadgroup(0)]],
         uint tgpig[[threadgroup_position_in_grid]],
@@ -2412,6 +2411,14 @@ kernel void kernel_cross_entropy_loss_back(
 
 }
 
+typedef decltype(kernel_cross_entropy_loss<float>) kernel_cross_entropy_loss_t;
+typedef decltype(kernel_cross_entropy_loss_back<float>) kernel_cross_entropy_loss_back_t;
+
+template [[host_name("kernel_cross_entropy_loss_f32")]]
+kernel kernel_cross_entropy_loss_t kernel_cross_entropy_loss<float>;
+
+template [[host_name("kernel_cross_entropy_loss_back_f32")]]
+kernel kernel_cross_entropy_loss_back_t kernel_cross_entropy_loss_back<float>;
 
 // ref: ggml.c:ggml_compute_forward_ssm_conv_f32
 kernel void kernel_ssm_conv_f32_f32(
From fc11fd3ff451d394f02e3a98487f01c14a8b7356 Mon Sep 17 00:00:00 2001
From: Ilia Ilmer
Date: Tue, 16 Dec 2025 22:00:30 -0500
Subject: [PATCH 3/4] working on ops function and pipeline

---
Review notes (ignored by git-am, not part of the commit message):
- The compute pipeline is now bound with ggml_metal_encoder_set_pipeline
  before the arguments are encoded; it was fetched but never set.
- dst is bound at buffer index 3: the kernel declares its buffers in the
  order args, logits, labels, dst (index 4 left dst unbound).
- The grid is nrows x 1 x 1 threadgroups: the kernel reads a 1-D threadgroup
  position as the row index, so ne00 x nrows x nrows launched far too many
  groups and indexed rows out of range.
- NOTE(review): the kernel stores one value per row at dst[row]. ggml's
  cross-entropy result is presumably a single scalar - confirm, and if so add
  a cross-row reduction (per-row partials plus a final sum) before merging.

 ggml/src/ggml-metal/ggml-metal-device.cpp | 21 +++++++++++++++
 ggml/src/ggml-metal/ggml-metal-ops.cpp    | 32 +++++++++++++++++++++--
 2 files changed, 51 insertions(+), 2 deletions(-)

diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index 680904d132..5bf41228f2 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -384,6 +384,27 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max(ggml_me
     return res;
 }
 
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cross_entropy(ggml_metal_library_t lib, const ggml_tensor * op) {
+    GGML_ASSERT(!op->src[0] || op->src[0]->type == GGML_TYPE_F32);
+
+    char base[256];
+    char name[256];
+
+    const ggml_type tsrc1 = GGML_TYPE_F32;
+
+    snprintf(base, 256, "kernel_cross_entropy_loss_%s", ggml_type_name(tsrc1));
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+    }
+
+    res.smem = 32*sizeof(float);
+
+    return res;
+}
+
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) {
     GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index 44ec0ddbd7..b27fee639a 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -1334,14 +1334,42 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
 }
 
 int ggml_metal_op_cross_entropy_loss(ggml_metal_op_t ctx, int idx){
-    const ggml_tensor * src0 = ctx->node(idx)->src[0]; // NOTE: logits
-    const ggml_tensor * src1 = ctx->node(idx)->src[1]; // NOTE: labels
+    ggml_tensor * op = ctx->node(idx);
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+    const ggml_tensor * src0 = op->src[0]; // NOTE: logits
+    const ggml_tensor * src1 = op->src[1]; // NOTE: labels
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(ggml_is_contiguous(src1));
 
+    const int32_t ne00 = src0->ne[0];
+    const int32_t nrows = ggml_nrows(src1);
+    ggml_metal_kargs_cross_entropy_loss args = {
+        /*int32_t*/ ne00,
+        /*int32_t*/ nrows,
+        /*int32_t*/ nrows,
+    };
+    int nth = 32;
+    auto pipeline = ggml_metal_library_get_pipeline_cross_entropy(lib, op);
+
+    const size_t smem = pipeline.smem;
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    if (op->src[1]) {
+        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+    } else {
+        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 2);
+    }
+
+    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
+
+    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
     return 1;
 }
 
From 646038f7765f8527d62735fb7676de019d4c6c44 Mon Sep 17 00:00:00 2001
From: Ilia Ilmer
Date: Tue, 16 Dec 2025 22:12:15 -0500
Subject: [PATCH 4/4] fix bugs/typos etc.

---
Review notes (ignored by git-am, not part of the commit message):
- The exp loop no longer stores exp_val. The forward kernel's dst only
  receives the per-row loss; writing dst[i] there overwrote that output and
  raced with the final dst[row] store from other threadgroups.

 ggml/src/ggml-metal/ggml-metal.metal | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 36074cdafa..e8b357a326 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -2282,10 +2282,10 @@ template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kerne
 
 template<typename T>
 kernel void kernel_cross_entropy_loss(
-        constant ggml_metal_kargs_cross_entropy & args,
-        device const char * src0,
-        device const char * src1,
-        device char * dst,
+        constant ggml_metal_kargs_cross_entropy_loss & args,
+        device const float * logits,
+        device const float * labels,
+        device float * dst,
         threadgroup float * buf [[threadgroup(0)]],
         uint tgpig[[threadgroup_position_in_grid]],
         uint tpitg[[thread_position_in_threadgroup]],
@@ -2315,7 +2315,7 @@ kernel void kernel_cross_entropy_loss(
     for (int i = tpitg; i < args.ne00; i += tptg){
         const float exp_val = exp(logits_row[i] - max_val);
         lsum += exp_val;
-        dst_row[i] = exp_val;
+        // exp_val only feeds the running sum; dst holds the per-row loss, not softmax values
     }
     threadgroup_barrier(mem_flags::mem_threadgroup);
     float sum = simd_sum(lsum);
@@ -2344,7 +2344,7 @@ kernel void kernel_cross_entropy_loss(
         loss = simd_sum(loss);
     }
     if (tpitg == 0) {
-        dst[i] = -loss / args.nrows;
+        dst[row] = -loss / args.nrows;
     }
 }
 