fixes
This commit is contained in:
parent
e3bba64e82
commit
07afb5128f
|
|
@ -2766,7 +2766,7 @@ static uint32_t get_fa_scalar_num_rows(uint32_t hsk, uint32_t hsv, FaRows rows,
|
|||
if (rows == FA_ROWS_1) {
|
||||
return 1;
|
||||
} else if (rows == FA_ROWS_SMALL) {
|
||||
return 4;
|
||||
return 8;
|
||||
}
|
||||
|
||||
if (hsv >= 192) {
|
||||
|
|
|
|||
|
|
@ -146,6 +146,7 @@ void main() {
|
|||
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
|
||||
|
||||
float max_mask = NEG_FLT_MAX_OVER_2;
|
||||
barrier();
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) % Bc;
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
|
|
|
|||
Loading…
Reference in New Issue