Fix bugs and typos in kernel_cross_entropy_loss (kargs type name, typed logits/labels/dst pointers, dst indexing).

This commit is contained in:
Ilia Ilmer 2025-12-16 22:12:15 -05:00
parent fc11fd3ff4
commit 646038f776
No known key found for this signature in database
1 changed file with 6 additions and 6 deletions

View File

@ -2282,10 +2282,10 @@ template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kerne
template <typename T> template <typename T>
kernel void kernel_cross_entropy_loss( kernel void kernel_cross_entropy_loss(
constant ggml_metal_kargs_cross_entropy & args, constant ggml_metal_kargs_cross_entropy_loss & args,
device const char * src0, device const float * logits,
device const char * src1, device const float * labels,
device char * dst, device float * dst,
threadgroup float * buf [[threadgroup(0)]], threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]], uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]], uint tpitg[[thread_position_in_threadgroup]],
@ -2315,7 +2315,7 @@ kernel void kernel_cross_entropy_loss(
for (int i = tpitg; i < args.ne00; i += tptg){ for (int i = tpitg; i < args.ne00; i += tptg){
const float exp_val = exp(logits_row[i] - max_val); const float exp_val = exp(logits_row[i] - max_val);
lsum += exp_val; lsum += exp_val;
dst_row[i] = exp_val; dst[i] = exp_val;
} }
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
float sum = simd_sum(lsum); float sum = simd_sum(lsum);
@ -2344,7 +2344,7 @@ kernel void kernel_cross_entropy_loss(
loss = simd_sum(loss); loss = simd_sum(loss);
} }
if (tpitg == 0) { if (tpitg == 0) {
dst[i] = -loss / args.nrows; dst[row] = -loss / args.nrows;
} }
} }