From 54523493b5ca0b53fa5cca212b362723ff8f46fd Mon Sep 17 00:00:00 2001
From: Ilia Ilmer <ilia.ilmer@example.com>
Date: Tue, 16 Dec 2025 19:10:17 -0500
Subject: [PATCH] starting on cross entropy: kernels

---
 ggml/src/ggml-metal/ggml-metal-impl.h  |  10 ++
 ggml/src/ggml-metal/ggml-metal-ops.cpp |  13 +++
 ggml/src/ggml-metal/ggml-metal-ops.h   |   1 +
 ggml/src/ggml-metal/ggml-metal.metal   | 134 +++++++++++++++++++++++++
 4 files changed, 158 insertions(+)

diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index 8944b07e90..a9967d5779 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -908,6 +908,16 @@ typedef struct {
     int64_t np;
 } ggml_metal_kargs_pool_2d;
 
+typedef struct {
+    int32_t ne00;  // row length (number of classes)
+    int32_t nrows; // number of rows; the kernels divide by it to take the mean
+} ggml_metal_kargs_cross_entropy_loss;
+
+typedef struct {
+    int32_t ne00;
+    int32_t nrows;
+} ggml_metal_kargs_cross_entropy_loss_back;
+
 typedef struct {
     int64_t ne00;
     uint64_t nb01;
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index e99c1763f6..44ec0ddbd7 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -1333,6 +1333,19 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_cross_entropy_loss(ggml_metal_op_t ctx, int idx) {
+    const ggml_tensor * src0 = ctx->node(idx)->src[0]; // NOTE: logits
+    const ggml_tensor * src1 = ctx->node(idx)->src[1]; // NOTE: labels
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+
+    // TODO: encode and dispatch kernel_cross_entropy_loss
+    return 1;
+}
+
 int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h
index 902b544523..7ed76cbf0b 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.h
+++ b/ggml/src/ggml-metal/ggml-metal-ops.h
@@ -57,6 +57,7 @@ int ggml_metal_op_cumsum            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_get_rows          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_set_rows          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_soft_max          (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_cross_entropy_loss(ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_ssm_conv          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_ssm_scan          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_rwkv              (ggml_metal_op_t ctx, int idx);
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 51bcbae309..1108dd22ad 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -2280,6 +2280,140 @@ template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max;
 template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4;
 template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4;
 
+// cross entropy loss: one threadgroup per row of ne00 logits/labels
+// dst[row] holds the per-row partial; the host sums over rows to get the scalar loss
+kernel void kernel_cross_entropy_loss(
+        constant ggml_metal_kargs_cross_entropy_loss & args,
+        device const float * logits, // src0
+        device const float * labels, // src1
+        device       float * dst,    // [nrows] per-row partial losses
+        threadgroup  float * buf [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint tptg[[threads_per_threadgroup]]) {
+    const int row = tgpig;
+
+    device const float * logits_row = logits + row * args.ne00;
+    device const float * labels_row = labels + row * args.ne00;
+
+    // 1. row max (for numerical stability of the softmax)
+    float lmax = -INFINITY;
+    for (int i = tpitg; i < args.ne00; i += tptg) {
+        lmax = MAX(lmax, logits_row[i]);
+    }
+
+    float max_val = simd_max(lmax);
+    if (tptg > N_SIMDWIDTH) {
+        if (sgitg == 0) buf[tiisg] = -INFINITY;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        if (tiisg == 0) buf[sgitg] = max_val;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        max_val = simd_max(buf[tiisg]);
+    }
+
+    // 2. sum of exp(logits - max)
+    float lsum = 0.0f;
+    for (int i = tpitg; i < args.ne00; i += tptg) {
+        lsum += exp(logits_row[i] - max_val);
+    }
+
+    float sum = simd_sum(lsum);
+    if (tptg > N_SIMDWIDTH) {
+        if (sgitg == 0) buf[tiisg] = 0.0f;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        if (tiisg == 0) buf[sgitg] = sum;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        sum = simd_sum(buf[tiisg]);
+    }
+
+    const float log_sum = log(sum);
+
+    // 3. loss_row = sum_i labels[i] * log_softmax(logits)[i]
+    float lloss = 0.0f;
+    for (int i = tpitg; i < args.ne00; i += tptg) {
+        const float log_softmax_i = logits_row[i] - max_val - log_sum;
+        lloss += log_softmax_i * labels_row[i];
+    }
+
+    float loss = simd_sum(lloss);
+    if (tptg > N_SIMDWIDTH) {
+        if (sgitg == 0) buf[tiisg] = 0.0f;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        if (tiisg == 0) buf[sgitg] = loss;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        loss = simd_sum(buf[tiisg]);
+    }
+
+    if (tpitg == 0) {
+        // negate and divide by the number of rows -> mean cross entropy
+        dst[row] = -loss / args.nrows;
+    }
+}
+
+// gradient of cross entropy loss w.r.t. the logits:
+// dst = (softmax(logits) - labels) * grad[0] / nrows
+kernel void kernel_cross_entropy_loss_back(
+        constant ggml_metal_kargs_cross_entropy_loss_back & args,
+        device const float * grad,   // gradient of the scalar loss
+        device const float * logits, // forward src0
+        device const float * labels, // forward src1
+        device       float * dst,
+        threadgroup  float * buf [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint tptg[[threads_per_threadgroup]]) {
+    const int row = tgpig;
+
+    device const float * logits_row = logits + row * args.ne00;
+    device const float * labels_row = labels + row * args.ne00;
+    device       float * dst_row    = dst    + row * args.ne00;
+
+    // 1. row max (for numerical stability of the softmax)
+    float lmax = -INFINITY;
+    for (int i = tpitg; i < args.ne00; i += tptg) {
+        lmax = MAX(lmax, logits_row[i]);
+    }
+
+    float max_val = simd_max(lmax);
+    if (tptg > N_SIMDWIDTH) {
+        if (sgitg == 0) buf[tiisg] = -INFINITY;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        if (tiisg == 0) buf[sgitg] = max_val;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        max_val = simd_max(buf[tiisg]);
+    }
+
+    // 2. sum of exp(logits - max); stash the exponentials in dst as scratch
+    float lsum = 0.0f;
+    for (int i = tpitg; i < args.ne00; i += tptg) {
+        const float exp_val = exp(logits_row[i] - max_val);
+        lsum += exp_val;
+        dst_row[i] = exp_val;
+    }
+
+    float sum = simd_sum(lsum);
+    if (tptg > N_SIMDWIDTH) {
+        if (sgitg == 0) buf[tiisg] = 0.0f;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        if (tiisg == 0) buf[sgitg] = sum;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        sum = simd_sum(buf[tiisg]);
+    }
+
+    const float inv_sum    = 1.0f / sum;
+    const float d_by_nrows = grad[0] / args.nrows;
+
+    // 3. each thread rescales the exponentials it wrote above (no barrier needed)
+    for (int i = tpitg; i < args.ne00; i += tptg) {
+        const float softmax_i = dst_row[i] * inv_sum;
+        dst_row[i] = (softmax_i - labels_row[i]) * d_by_nrows;
+    }
+}
+
 // ref: ggml.c:ggml_compute_forward_ssm_conv_f32
 kernel void kernel_ssm_conv_f32_f32(
 constant ggml_metal_kargs_ssm_conv & args,