working on ops function and pipeline
This commit is contained in:
parent
d9b5b17411
commit
fc11fd3ff4
|
|
@ -384,6 +384,27 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max(ggml_me
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cross_entropy(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||||
|
GGML_ASSERT(!op->src[0] || op->src[0]->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
|
char base[256];
|
||||||
|
char name[256];
|
||||||
|
|
||||||
|
const ggml_type tsrc1 = GGML_TYPE_F32;
|
||||||
|
|
||||||
|
snprintf(base, 256, "kernel_cross_entropy_loss_%s", ggml_type_name(tsrc1));
|
||||||
|
snprintf(name, 256, "%s", base);
|
||||||
|
|
||||||
|
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||||
|
if (!res.pipeline) {
|
||||||
|
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
res.smem = 32*sizeof(float);
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) {
|
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||||
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
|
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
||||||
|
|
|
||||||
|
|
@ -1334,14 +1334,42 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
|
||||||
}
|
}
|
||||||
|
|
||||||
int ggml_metal_op_cross_entropy_loss(ggml_metal_op_t ctx, int idx){
|
int ggml_metal_op_cross_entropy_loss(ggml_metal_op_t ctx, int idx){
|
||||||
const ggml_tensor * src0 = ctx->node(idx)->src[0]; // NOTE: logits
|
ggml_tensor * op = ctx->node(idx);
|
||||||
const ggml_tensor * src1 = ctx->node(idx)->src[1]; // NOTE: labels
|
ggml_metal_library_t lib = ctx->lib;
|
||||||
|
ggml_metal_encoder_t enc = ctx->enc;
|
||||||
|
const ggml_tensor * src0 = op->src[0]; // NOTE: logits
|
||||||
|
const ggml_tensor * src1 = op->src[1]; // NOTE: labels
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||||
|
|
||||||
|
const int32_t ne00 = src0->ne[0];
|
||||||
|
const int32_t nrows = ggml_nrows(src1);
|
||||||
|
ggml_metal_kargs_cross_entropy_loss args = {
|
||||||
|
/*int32_t*/ ne00,
|
||||||
|
/*int32_t*/ nrows,
|
||||||
|
/*int32_t*/ nrows,
|
||||||
|
};
|
||||||
|
int nth = 32;
|
||||||
|
auto pipeline = ggml_metal_library_get_pipeline_cross_entropy(lib, op);
|
||||||
|
|
||||||
|
const size_t smem = pipeline.smem;
|
||||||
|
|
||||||
|
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
|
||||||
|
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||||
|
if (op->src[1]) {
|
||||||
|
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
||||||
|
} else {
|
||||||
|
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 4);
|
||||||
|
|
||||||
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||||
|
|
||||||
|
ggml_metal_encoder_dispatch_threadgroups(enc, ne00, nrows, nrows, nth, 1, 1);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue