add kernels, add pipeline header declaration
This commit is contained in:
parent
54523493b5
commit
d9b5b17411
|
|
@ -116,6 +116,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_bl
|
||||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
|
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cross_entropy (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
|
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
|
||||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||||
|
|
|
||||||
|
|
@ -910,14 +910,14 @@ typedef struct {
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
int32_t ne00;
|
int32_t ne00;
|
||||||
int32_t args;
|
int32_t nrows;
|
||||||
int32_t k;
|
int32_t k;
|
||||||
} ggm_metal_kargs_cross_entropy_loss;
|
} ggml_metal_kargs_cross_entropy_loss;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
int32_t ne00;
|
int32_t ne00;
|
||||||
int32_t args;
|
int32_t nrows;
|
||||||
} ggm_metal_kargs_cross_entropy_loss_back;
|
} ggml_metal_kargs_cross_entropy_loss_back;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
int64_t ne00;
|
int64_t ne00;
|
||||||
|
|
|
||||||
|
|
@ -2285,7 +2285,6 @@ kernel void kernel_cross_entropy_loss(
|
||||||
constant ggml_metal_kargs_cross_entropy & args,
|
constant ggml_metal_kargs_cross_entropy & args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device const char * src2,
|
|
||||||
device char * dst,
|
device char * dst,
|
||||||
threadgroup float * buf [[threadgroup(0)]],
|
threadgroup float * buf [[threadgroup(0)]],
|
||||||
uint tgpig[[threadgroup_position_in_grid]],
|
uint tgpig[[threadgroup_position_in_grid]],
|
||||||
|
|
@ -2412,6 +2411,14 @@ kernel void kernel_cross_entropy_loss_back(
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
typedef decltype(kernel_cross_entropy_loss<float>) kernel_cross_entropy_loss_t;
|
||||||
|
typedef decltype(kernel_cross_entropy_loss_back<float>) kernel_cross_entropy_loss_back_t;
|
||||||
|
|
||||||
|
template [[host_name("kernel_cross_entropy_loss_f32")]]
|
||||||
|
kernel kernel_cross_entropy_loss_t kernel_cross_entropy_loss<float>;
|
||||||
|
|
||||||
|
template [[host_name("kernel_cross_entropy_loss_back_f32")]]
|
||||||
|
kernel kernel_cross_entropy_loss_back_t kernel_cross_entropy_loss_back<float>;
|
||||||
|
|
||||||
// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
|
// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
|
||||||
kernel void kernel_ssm_conv_f32_f32(
|
kernel void kernel_ssm_conv_f32_f32(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue