fix bugs/typos etc.

This commit is contained in:
Ilia Ilmer 2025-12-16 22:12:15 -05:00
parent fc11fd3ff4
commit 646038f776
No known key found for this signature in database
1 changed file with 6 additions and 6 deletions

View File

@ -2282,10 +2282,10 @@ template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kerne
template <typename T>
kernel void kernel_cross_entropy_loss(
constant ggml_metal_kargs_cross_entropy & args,
device const char * src0,
device const char * src1,
device char * dst,
constant ggml_metal_kargs_cross_entropy_loss & args,
device const float * logits,
device const float * labels,
device float * dst,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
@ -2315,7 +2315,7 @@ kernel void kernel_cross_entropy_loss(
for (int i = tpitg; i < args.ne00; i += tptg){
const float exp_val = exp(logits_row[i] - max_val);
lsum += exp_val;
dst_row[i] = exp_val;
dst[i] = exp_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float sum = simd_sum(lsum);
@ -2344,7 +2344,7 @@ kernel void kernel_cross_entropy_loss(
loss = simd_sum(loss);
}
if (tpitg == 0) {
dst[i] = -loss / args.nrows;
dst[row] = -loss / args.nrows;
}
}