diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 0a8b9211a7..c70d59101d 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -116,6 +116,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_bl struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cross_entropy (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index a9967d5779..d8422490c5 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -910,14 +910,14 @@ typedef struct { typedef struct { int32_t ne00; - int32_t args; + int32_t nrows; int32_t k; -} ggm_metal_kargs_cross_entropy_loss; +} ggml_metal_kargs_cross_entropy_loss; typedef struct { int32_t ne00; - int32_t args; -} ggm_metal_kargs_cross_entropy_loss_back; + int32_t nrows; +} ggml_metal_kargs_cross_entropy_loss_back; typedef struct { int64_t ne00; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 1108dd22ad..36074cdafa 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2285,7 +2285,6 @@ kernel void kernel_cross_entropy_loss( constant ggml_metal_kargs_cross_entropy & args, device const char * src0, device const char * src1, - device const char * src2, device char * dst, threadgroup float * buf [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], @@ -2412,6 +2411,14 @@ kernel void kernel_cross_entropy_loss_back( } +typedef decltype(kernel_cross_entropy_loss) kernel_cross_entropy_loss_t; +typedef decltype(kernel_cross_entropy_loss_back) kernel_cross_entropy_loss_back_t; + +template [[host_name("kernel_cross_entropy_loss_f32")]] +kernel kernel_cross_entropy_loss_t kernel_cross_entropy_loss; + +template [[host_name("kernel_cross_entropy_loss_back_f32")]] +kernel kernel_cross_entropy_loss_back_t kernel_cross_entropy_loss_back; // ref: ggml.c:ggml_compute_forward_ssm_conv_f32 kernel void kernel_ssm_conv_f32_f32(