add kernels, add pipeline header declaration

This commit is contained in:
Ilia Ilmer 2025-12-16 21:59:58 -05:00
parent 54523493b5
commit d9b5b17411
No known key found for this signature in database
3 changed files with 13 additions and 5 deletions

View File

@@ -116,6 +116,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_bl
// Pipeline getters, one per operation. Each takes the Metal library handle
// and the op tensor, and returns a ggml_metal_pipeline_with_params by value.
// The ssm_conv_batched variant takes an extra int, ssm_conv_bs
// (presumably a batch size - confirm against the definition).
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cross_entropy (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);

View File

@@ -910,14 +910,14 @@ typedef struct {
typedef struct {
int32_t ne00;
int32_t args;
int32_t nrows;
int32_t k;
} ggm_metal_kargs_cross_entropy_loss;
} ggml_metal_kargs_cross_entropy_loss;
typedef struct {
int32_t ne00;
int32_t args;
} ggm_metal_kargs_cross_entropy_loss_back;
int32_t nrows;
} ggml_metal_kargs_cross_entropy_loss_back;
typedef struct {
int64_t ne00;

View File

@@ -2285,7 +2285,6 @@ kernel void kernel_cross_entropy_loss(
constant ggml_metal_kargs_cross_entropy & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
@@ -2412,6 +2411,14 @@ kernel void kernel_cross_entropy_loss_back(
}
// Aliases for the function types of the float specializations of the forward
// and backward cross-entropy kernels; used below to spell the instantiations.
typedef decltype(kernel_cross_entropy_loss<float>) kernel_cross_entropy_loss_t;
typedef decltype(kernel_cross_entropy_loss_back<float>) kernel_cross_entropy_loss_back_t;
// Explicit instantiation of the forward kernel for float. The host_name
// attribute gives it the name "kernel_cross_entropy_loss_f32".
template [[host_name("kernel_cross_entropy_loss_f32")]]
kernel kernel_cross_entropy_loss_t kernel_cross_entropy_loss<float>;
// Explicit instantiation of the backward kernel for float, named
// "kernel_cross_entropy_loss_back_f32" via host_name.
template [[host_name("kernel_cross_entropy_loss_back_f32")]]
kernel kernel_cross_entropy_loss_back_t kernel_cross_entropy_loss_back<float>;
// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
kernel void kernel_ssm_conv_f32_f32(