vulkan: fix flash attention dot product precision (#20589)

This commit is contained in:
Ruben Ortlam 2026-03-16 10:45:49 +01:00 committed by GitHub
parent de8f01c2d7
commit 46dba9fce8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 2 additions and 2 deletions

View File

@ -245,7 +245,7 @@ void main() {
#endif
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf));
Sf[r][c] += dot(ACC_TYPEV4(Q_cache[r]), ACC_TYPEV4(K_Tf));
}
}
}
@ -270,7 +270,7 @@ void main() {
#endif
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf));
Sf[r][c] += dot(ACC_TYPEV4(Qf[tile_row(r) * qf_stride + d * D_split + d_tid]), ACC_TYPEV4(K_Tf));
}
}
}