metal : fix l2 norm scale (#20493)

This commit is contained in:
Georgi Gerganov 2026-03-13 11:43:20 +02:00 committed by GitHub
parent 983df142a9
commit 73c9eb8ced
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 2 additions and 2 deletions

View File

@ -1156,7 +1156,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_RWKV_WKV7:
return true;
case GGML_OP_GATED_DELTA_NET:
return op->src[2]->ne[0] % 32 == 0;
return has_simdgroup_reduction && op->src[2]->ne[0] % 32 == 0;
case GGML_OP_SOLVE_TRI:
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:

View File

@ -3006,7 +3006,7 @@ kernel void kernel_l2_norm_impl(
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
const float scale = 1.0f/sqrt(max(sumf, args.eps));
const float scale = 1.0f/max(sqrt(sumf), args.eps);
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
y[i00] = x[i00] * scale;