fuse lf accumulation, pf and v accumulation into a loop

This commit is contained in:
Ruben Ortlam 2026-02-08 08:32:00 +01:00
parent ca5ec63cfb
commit 50a420e044
1 changed files with 9 additions and 16 deletions

View File

@ -236,7 +236,6 @@ void main() {
}
}
FLOAT_TYPE Pf[rows_per_thread][cols_per_thread];
float eMf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
float rowmaxf = NEG_FLT_MAX_OVER_2;
@ -252,21 +251,8 @@ void main() {
// P = e^(S - M)
// eM = e^(Mold - M)
Mf[r] = max(rowmaxf, Moldf);
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
Pf[r][c] = FLOAT_TYPE(exp(float(Sf[r][c]) - Mf[r]));
}
eMf[r] = exp(Moldf - Mf[r]);
// Compute sum across row of P
float rowsumf = 0.0;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
rowsumf += Pf[r][c];
}
Lf[r] = eMf[r]*Lf[r] + rowsumf;
Lf[r] = eMf[r]*Lf[r];
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
@ -279,6 +265,13 @@ void main() {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
FLOAT_TYPE Pf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Pf[r] = FLOAT_TYPE(exp(float(Sf[r][c]) - Mf[r]));
Lf[r] += Pf[r];
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
@ -289,7 +282,7 @@ void main() {
FLOAT_TYPEV4 Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
#endif
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] += ACC_TYPEV4(Pf[r][c] * Vf);
Of[r][d] += ACC_TYPEV4(Pf[r] * Vf);
}
}
}