From fc11fd3ff451d394f02e3a98487f01c14a8b7356 Mon Sep 17 00:00:00 2001
From: Ilia Ilmer
Date: Tue, 16 Dec 2025 22:00:30 -0500
Subject: [PATCH] working on ops function and pipeline

---
 ggml/src/ggml-metal/ggml-metal-device.cpp | 24 ++++++++++++++
 ggml/src/ggml-metal/ggml-metal-ops.cpp    | 40 ++++++++++++++++++++++-
 2 files changed, 62 insertions(+), 2 deletions(-)

diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index 680904d132..5bf41228f2 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -384,6 +384,30 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max(ggml_me
     return res;
 }
 
+// Returns the pipeline for "kernel_cross_entropy_loss_<type>" (type fixed to F32),
+// compiling it when the library lookup yields no pipeline. The result's smem
+// field is set to the size of 32 floats.
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cross_entropy(ggml_metal_library_t lib, const ggml_tensor * op) {
+    GGML_ASSERT(!op->src[0] || op->src[0]->type == GGML_TYPE_F32);
+
+    char base[256];
+    char name[256];
+
+    const ggml_type tsrc1 = GGML_TYPE_F32;
+
+    snprintf(base, 256, "kernel_cross_entropy_loss_%s", ggml_type_name(tsrc1));
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+    }
+
+    res.smem = 32*sizeof(float);
+
+    return res;
+}
+
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) {
     GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index 44ec0ddbd7..b27fee639a 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -1334,14 +1334,50 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
 }
 
 int ggml_metal_op_cross_entropy_loss(ggml_metal_op_t ctx, int idx){
-    const ggml_tensor * src0 = ctx->node(idx)->src[0]; // NOTE: logits
-    const ggml_tensor * src1 = ctx->node(idx)->src[1]; // NOTE: labels
+    // Encodes the cross-entropy loss node at `idx`: binds the pipeline, kernel
+    // args, logits, labels and destination, then dispatches. Always returns 1.
+    ggml_tensor * op = ctx->node(idx);
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+    const ggml_tensor * src0 = op->src[0]; // NOTE: logits
+    const ggml_tensor * src1 = op->src[1]; // NOTE: labels
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(ggml_is_contiguous(src1));
 
+    // ne00 is read from the logits and nrows from the labels, so both must agree
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+
+    const int32_t ne00 = src0->ne[0];
+    const int32_t nrows = ggml_nrows(src1);
+    ggml_metal_kargs_cross_entropy_loss args = {
+        /*int32_t*/ ne00,
+        /*int32_t*/ nrows,
+        /*int32_t*/ nrows, // NOTE(review): nrows is passed twice - confirm against the kargs struct
+    };
+    int nth = 32;
+    auto pipeline = ggml_metal_library_get_pipeline_cross_entropy(lib, op);
+
+    const size_t smem = pipeline.smem;
+
+    // bind the fetched pipeline - without this call it is only used for its
+    // smem value and never reaches the encoder
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    // src1 was already dereferenced by the asserts above, so it cannot be null here
+    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+
+    // NOTE(review): buffer index 3 is left unbound - confirm the kernel expects dst at 4
+    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 4);
+
+    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+    // NOTE(review): this launches ne00*nrows*nrows threadgroups of nth threads -
+    // confirm against the kernel's indexing
+    ggml_metal_encoder_dispatch_threadgroups(enc, ne00, nrows, nrows, nth, 1, 1);
     return 1;
 }
 