fix bugs/typos etc.
This commit is contained in:
parent
fc11fd3ff4
commit
646038f776
|
|
@ -2282,10 +2282,10 @@ template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kerne
|
|||
|
||||
template <typename T>
|
||||
kernel void kernel_cross_entropy_loss(
|
||||
constant ggml_metal_kargs_cross_entropy & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
constant ggml_metal_kargs_cross_entropy_loss & args,
|
||||
device const float * logits,
|
||||
device const float * labels,
|
||||
device float * dst,
|
||||
threadgroup float * buf [[threadgroup(0)]],
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
|
|
@ -2315,7 +2315,7 @@ kernel void kernel_cross_entropy_loss(
|
|||
for (int i = tpitg; i < args.ne00; i += tptg){
|
||||
const float exp_val = exp(logits_row[i] - max_val);
|
||||
lsum += exp_val;
|
||||
dst_row[i] = exp_val;
|
||||
dst[i] = exp_val;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
float sum = simd_sum(lsum);
|
||||
|
|
@ -2344,7 +2344,7 @@ kernel void kernel_cross_entropy_loss(
|
|||
loss = simd_sum(loss);
|
||||
}
|
||||
if (tpitg == 0) {
|
||||
dst[i] = -loss / args.nrows;
|
||||
dst[row] = -loss / args.nrows;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue