cuda : unroll Q*K^T loop
This commit is contained in:
parent
3b1c4e7673
commit
5b263dd83a
|
|
@@ -6571,6 +6571,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
|
||||
// Q*K^T
|
||||
{
|
||||
#pragma unroll
|
||||
for (int cc = 0; cc < C/16; ++cc) {
|
||||
half16x16_acc mqk[Q16];
|
||||
for (int j = 0; j < Q16; ++j) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue