diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 36074cdafa..e8b357a326 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2282,10 +2282,10 @@ template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kerne template kernel void kernel_cross_entropy_loss( - constant ggml_metal_kargs_cross_entropy & args, - device const char * src0, - device const char * src1, - device char * dst, + constant ggml_metal_kargs_cross_entropy_loss & args, + device const float * logits, + device const float * labels, + device float * dst, threadgroup float * buf [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], @@ -2315,7 +2315,7 @@ kernel void kernel_cross_entropy_loss( for (int i = tpitg; i < args.ne00; i += tptg){ const float exp_val = exp(logits_row[i] - max_val); lsum += exp_val; - dst_row[i] = exp_val; + dst[i] = exp_val; } threadgroup_barrier(mem_flags::mem_threadgroup); float sum = simd_sum(lsum); @@ -2344,7 +2344,7 @@ kernel void kernel_cross_entropy_loss( loss = simd_sum(loss); } if (tpitg == 0) { - dst[i] = -loss / args.nrows; + dst[row] = -loss / args.nrows; } }