optimize masksh use

This commit is contained in:
Ruben Ortlam 2026-02-06 13:32:33 +01:00
parent 9b309bbc51
commit c0f419351c
2 changed files with 44 additions and 47 deletions

View File

@ -43,8 +43,7 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in ACC_
return elem;
}
const uint32_t tmpsh_reduction_size = row_split == 1 ? num_subgroups * D_split : 0;
const uint32_t tmpsh_size = tmpsh_reduction_size > 4 ? tmpsh_reduction_size : 4;
const uint32_t tmpsh_size = row_split == 1 ? num_subgroups * D_split : 1;
shared float tmpsh[tmpsh_size];
shared ACC_TYPEV4 tmpsh_accv4[tmpsh_size];
@ -128,49 +127,51 @@ void main() {
uint32_t mask_opt = 0;
uint32_t mask_opt_idx = ~0;
uint32_t mask_opt_bits = 0;
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
if (MASK_ENABLE) {
if (USE_MASK_OPT && mask_opt_idx != j / 16) {
mask_opt_idx = j / 16;
mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
}
mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
// skip this block
continue;
}
// Only load if the block is not all zeros
if (mask_opt_bits != MASK_OPT_ALL_ZERO) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
if (USE_MASK_OPT && mask_opt_idx != j / 16) {
mask_opt_idx = j / 16;
mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
}
uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
// skip this block
continue;
}
// Only load if the block is not all zeros
if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
float max_mask = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br) {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
FLOAT_TYPE m = FLOAT_TYPE(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
masksh[c][r] = m;
max_mask = max(max_mask, float(m));
} else {
masksh[c][r] = FLOAT_TYPE(0);
float max_mask = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br) {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
FLOAT_TYPE m = FLOAT_TYPE(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
masksh[c][r] = m;
max_mask = max(max_mask, float(m));
} else {
masksh[c][r] = FLOAT_TYPE(0);
}
}
}
}
// skip the block if the mask is entirely -inf
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
barrier();
if (gl_SubgroupInvocationID == 0) {
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
}
barrier();
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
max_mask = max(max_mask, tmpsh[s]);
}
if (max_mask <= NEG_FLT_MAX_OVER_2) {
continue;
// skip the block if the mask is entirely -inf
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
barrier();
if (gl_SubgroupInvocationID == 0) {
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
}
barrier();
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
max_mask = max(max_mask, tmpsh[s]);
}
if (max_mask <= NEG_FLT_MAX_OVER_2) {
continue;
}
}
}
@ -181,7 +182,6 @@ void main() {
}
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
@ -226,7 +226,6 @@ void main() {
Sf[r][c] += slope[r]*mvf;
}
}
barrier();
}
FLOAT_TYPE Pf[rows_per_thread][cols_per_thread];
@ -286,8 +285,6 @@ void main() {
}
}
}
barrier();
}
// prevent race on tmpsh

View File

@ -153,22 +153,22 @@ void main() {
uint32_t mask_opt = 0;
uint32_t mask_opt_idx = ~0;
uint32_t mask_opt_bits = 0;
f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize];
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize];
[[unroll]] for (uint32_t idx = 0; idx < mask_cache.length(); ++idx) {
mask_cache[idx] = f16vec4(0);
}
if (MASK_ENABLE) {
if (USE_MASK_OPT && mask_opt_idx != j / 16) {
mask_opt_idx = j / 16;
mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
}
uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
// skip this block
continue;
@ -329,7 +329,7 @@ void main() {
barrier();
}
if (MASK_ENABLE) {
if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) / (Br / 4);
uint32_t r = (idx + tid) % (Br / 4);