cuda : unroll Q*K^T loop
This commit is contained in:
parent
3b1c4e7673
commit
5b263dd83a
|
|
@ -6571,6 +6571,7 @@ static __global__ void flash_attn_ext_f16(
|
||||||
|
|
||||||
// Q*K^T
|
// Q*K^T
|
||||||
{
|
{
|
||||||
|
#pragma unroll
|
||||||
for (int cc = 0; cc < C/16; ++cc) {
|
for (int cc = 0; cc < C/16; ++cc) {
|
||||||
half16x16_acc mqk[Q16];
|
half16x16_acc mqk[Q16];
|
||||||
for (int j = 0; j < Q16; ++j) {
|
for (int j = 0; j < Q16; ++j) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue