add kernels, add pipeline header declaration
This commit is contained in:
parent
54523493b5
commit
d9b5b17411
|
|
@ -116,6 +116,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_bl
|
|||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cross_entropy (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
|
|
|
|||
|
|
@ -910,14 +910,14 @@ typedef struct {
|
|||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t args;
|
||||
int32_t nrows;
|
||||
int32_t k;
|
||||
} ggm_metal_kargs_cross_entropy_loss;
|
||||
} ggml_metal_kargs_cross_entropy_loss;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t args;
|
||||
} ggm_metal_kargs_cross_entropy_loss_back;
|
||||
int32_t nrows;
|
||||
} ggml_metal_kargs_cross_entropy_loss_back;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
|
|
|
|||
|
|
@ -2285,7 +2285,6 @@ kernel void kernel_cross_entropy_loss(
|
|||
constant ggml_metal_kargs_cross_entropy & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * dst,
|
||||
threadgroup float * buf [[threadgroup(0)]],
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
|
|
@ -2412,6 +2411,14 @@ kernel void kernel_cross_entropy_loss_back(
|
|||
|
||||
}
|
||||
|
||||
typedef decltype(kernel_cross_entropy_loss<float>) kernel_cross_entropy_loss_t;
|
||||
typedef decltype(kernel_cross_entropy_loss_back<float>) kernel_cross_entropy_loss_back_t;
|
||||
|
||||
template [[host_name("kernel_cross_entropy_loss_f32")]]
|
||||
kernel kernel_cross_entropy_loss_t kernel_cross_entropy_loss<float>;
|
||||
|
||||
template [[host_name("kernel_cross_entropy_loss_back_f32")]]
|
||||
kernel kernel_cross_entropy_loss_back_t kernel_cross_entropy_loss_back<float>;
|
||||
|
||||
// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
|
||||
kernel void kernel_ssm_conv_f32_f32(
|
||||
|
|
|
|||
Loading…
Reference in New Issue