diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index 680904d132..5bf41228f2 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -384,6 +384,29 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max(ggml_me
     return res;
 }
 
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cross_entropy(ggml_metal_library_t lib, const ggml_tensor * op) {
+    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
+
+    char base[256];
+    char name[256];
+
+    const ggml_type tsrc1 = GGML_TYPE_F32;
+
+    snprintf(base, 256, "kernel_cross_entropy_loss_%s", ggml_type_name(tsrc1));
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+    }
+
+    // one float per simdgroup for the threadgroup reductions (N_SIMDWIDTH == 32)
+    res.smem = 32*sizeof(float);
+
+    return res;
+}
+
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) {
     GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h
index 0a8b9211a7..c70d59101d 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.h
+++ b/ggml/src/ggml-metal/ggml-metal-device.h
@@ -116,6 +116,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_bl
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add       (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri              (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max         (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cross_entropy    (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv         (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan         (ggml_metal_library_t lib, const struct ggml_tensor * op);
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index 8944b07e90..d8422490c5 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -908,6 +908,17 @@ typedef struct {
     int64_t np;
 } ggml_metal_kargs_pool_2d;
 
+typedef struct {
+    int32_t ne00;
+    int32_t nrows;
+    int32_t k; // currently unused by the kernel
+} ggml_metal_kargs_cross_entropy_loss;
+
+typedef struct {
+    int32_t ne00;
+    int32_t nrows;
+} ggml_metal_kargs_cross_entropy_loss_back;
+
 typedef struct {
     int64_t ne00;
     uint64_t nb01;
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index e99c1763f6..b27fee639a 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -1333,6 +1333,48 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_cross_entropy_loss(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    const ggml_tensor * src0 = op->src[0]; // logits
+    const ggml_tensor * src1 = op->src[1]; // labels
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+
+    const int32_t ne00  = src0->ne[0];
+    const int32_t nrows = ggml_nrows(src0);
+
+    ggml_metal_kargs_cross_entropy_loss args = {
+        /*.ne00  =*/ ne00,
+        /*.nrows =*/ nrows,
+        /*.k     =*/ nrows, // unused by the kernel
+    };
+
+    auto pipeline = ggml_metal_library_get_pipeline_cross_entropy(lib, op);
+
+    const size_t smem = pipeline.smem;
+
+    const int nth = 32;
+
+    // kernel buffer layout: args = 0, logits = 1, labels = 2, dst = 3
+    ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(src0), 1);
+    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(src1), 2);
+    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op),   3);
+
+    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+    // the result is a single scalar -> launch one threadgroup that loops over the rows
+    ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
+
+    return 1;
+}
+
 int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h
index 902b544523..7ed76cbf0b 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.h
+++ b/ggml/src/ggml-metal/ggml-metal-ops.h
@@ -57,6 +57,7 @@ int ggml_metal_op_cumsum            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_get_rows          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_set_rows          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_soft_max          (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_cross_entropy_loss(ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_ssm_conv          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_ssm_scan          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_rwkv              (ggml_metal_op_t ctx, int idx);
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 51bcbae309..e8b357a326 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -2280,6 +2280,176 @@ template [[host_name("kernel_soft_max_f32")]]   kernel kernel_soft_max_t   kerne
 template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4;
 template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4;
 
+
+// computes the scalar cross-entropy loss over all rows:
+//
+//   dst[0] = -1/nrows * sum_r sum_i labels[r][i] * log(softmax(logits[r])[i])
+//
+// launched as a single threadgroup that loops over the rows: the result is one
+// scalar and Metal threadgroups cannot reduce across each other in one dispatch
+template <typename T>
+kernel void kernel_cross_entropy_loss(
+        constant ggml_metal_kargs_cross_entropy_loss & args,
+        device const float * logits,
+        device const float * labels,
+        device       float * dst,
+        threadgroup  float * buf [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint tptg[[threads_per_threadgroup]]) {
+    float lloss = 0.0f;
+
+    for (int row = 0; row < args.nrows; ++row) {
+        device const float * logits_row = logits + (int64_t) row*args.ne00;
+        device const float * labels_row = labels + (int64_t) row*args.ne00;
+
+        // row max (numerical stability of the softmax)
+        float lmax = -INFINITY;
+        for (int i = tpitg; i < args.ne00; i += tptg) {
+            lmax = MAX(lmax, logits_row[i]);
+        }
+
+        float max_val = simd_max(lmax);
+        if (tptg > N_SIMDWIDTH) {
+            if (sgitg == 0) {
+                buf[tiisg] = -INFINITY;
+            }
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+            if (tiisg == 0) {
+                buf[sgitg] = max_val;
+            }
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+            max_val = simd_max(buf[tiisg]);
+            threadgroup_barrier(mem_flags::mem_threadgroup); // buf is reused below
+        }
+
+        float lsum = 0.0f;
+        for (int i = tpitg; i < args.ne00; i += tptg) {
+            lsum += exp(logits_row[i] - max_val);
+        }
+
+        float sum = simd_sum(lsum);
+        if (tptg > N_SIMDWIDTH) {
+            if (sgitg == 0) {
+                buf[tiisg] = 0.0f;
+            }
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+            if (tiisg == 0) {
+                buf[sgitg] = sum;
+            }
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+            sum = simd_sum(buf[tiisg]);
+            threadgroup_barrier(mem_flags::mem_threadgroup); // buf is reused next iteration
+        }
+
+        const float log_sum = log(sum);
+
+        // weight the log-softmax by the labels (not the logits)
+        for (int i = tpitg; i < args.ne00; i += tptg) {
+            lloss += (logits_row[i] - max_val - log_sum)*labels_row[i];
+        }
+    }
+
+    float loss = simd_sum(lloss);
+    if (tptg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        if (tiisg == 0) {
+            buf[sgitg] = loss;
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        loss = simd_sum(buf[tiisg]);
+    }
+
+    if (tpitg == 0) {
+        dst[0] = -loss/args.nrows;
+    }
+}
+
+// gradient of the cross-entropy loss wrt the logits:
+//
+//   dst[r][i] = (softmax(logits[r])[i] - labels[r][i]) * grad[0]/nrows
+//
+// one threadgroup per row
+template <typename T>
+kernel void kernel_cross_entropy_loss_back(
+        constant ggml_metal_kargs_cross_entropy_loss_back & args,
+        device const float * grad,   // scalar gradient of the loss
+        device const float * logits, // src0
+        device const float * labels, // src1
+        device       float * dst,
+        threadgroup  float * buf [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint tptg[[threads_per_threadgroup]]) {
+    const int row = tgpig; // was tpitg: each threadgroup processes one row
+
+    device const float * logits_row = logits + (int64_t) row*args.ne00;
+    device const float * labels_row = labels + (int64_t) row*args.ne00;
+    device       float * dst_row    = dst    + (int64_t) row*args.ne00;
+
+    // row max (numerical stability of the softmax)
+    float lmax = -INFINITY;
+    for (int i = tpitg; i < args.ne00; i += tptg) {
+        lmax = MAX(lmax, logits_row[i]);
+    }
+
+    float max_val = simd_max(lmax);
+    if (tptg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = -INFINITY;
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        if (tiisg == 0) {
+            buf[sgitg] = max_val;
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        max_val = simd_max(buf[tiisg]);
+        threadgroup_barrier(mem_flags::mem_threadgroup); // buf is reused below
+    }
+
+    // exp(logits - max), stored in dst as scratch - overwritten below
+    float lsum = 0.0f;
+    for (int i = tpitg; i < args.ne00; i += tptg) {
+        const float exp_val = exp(logits_row[i] - max_val);
+        lsum += exp_val;
+        dst_row[i] = exp_val;
+    }
+
+    float sum = simd_sum(lsum);
+    if (tptg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        if (tiisg == 0) {
+            buf[sgitg] = sum;
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        sum = simd_sum(buf[tiisg]);
+    }
+
+    const float inv_sum    = 1.0f/sum;
+    const float d_by_nrows = grad[0]/args.nrows;
+
+    for (int i = tpitg; i < args.ne00; i += tptg) {
+        const float softmax_i = dst_row[i]*inv_sum;
+        dst_row[i] = (softmax_i - labels_row[i])*d_by_nrows;
+    }
+}
+
+typedef decltype(kernel_cross_entropy_loss<float>)      kernel_cross_entropy_loss_t;
+typedef decltype(kernel_cross_entropy_loss_back<float>) kernel_cross_entropy_loss_back_t;
+
+template [[host_name("kernel_cross_entropy_loss_f32")]]      kernel kernel_cross_entropy_loss_t      kernel_cross_entropy_loss<float>;
+template [[host_name("kernel_cross_entropy_loss_back_f32")]] kernel kernel_cross_entropy_loss_back_t kernel_cross_entropy_loss_back<float>;
+
 // ref: ggml.c:ggml_compute_forward_ssm_conv_f32
 kernel void kernel_ssm_conv_f32_f32(
         constant ggml_metal_kargs_ssm_conv & args,