tmpsh size fix

This commit is contained in:
Ruben Ortlam 2026-02-14 11:43:31 +01:00
parent 0b4b0d2e57
commit 02ccf81496
1 changed files with 16 additions and 16 deletions

View File

@ -35,7 +35,7 @@ layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
layout (binding = 3) readonly buffer M {float16_t data_m[];};
// If SubGroupSize is set to 0xFFFFFFFF then only use shmem reductions
const uint32_t tmpsh_size = (SubGroupSize != SUBGROUPS_DISABLED) ? (row_split == 1 ? num_subgroups * D_split : 1) : WorkGroupSize;
const uint32_t tmpsh_size = (SubGroupSize != SUBGROUPS_DISABLED) ? (row_split == 1 ? num_subgroups * D_split : num_subgroups) : WorkGroupSize;
shared float tmpsh[tmpsh_size];
shared ACC_TYPEV4 tmpsh_accv4[tmpsh_size];
@ -143,7 +143,7 @@ void main() {
if (mask_opt_bits != MASK_OPT_ALL_ZERO) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
float max_mask = NEG_FLT_MAX_OVER_2;
float max_mask = NEG_FLT_MAX_OVER_2;
barrier();
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
@ -152,25 +152,25 @@ void main() {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
FLOAT_TYPE m = FLOAT_TYPE(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
masksh[c * masksh_stride + r] = m;
max_mask = max(max_mask, float(m));
max_mask = max(max_mask, float(m));
} else {
masksh[c * masksh_stride + r] = FLOAT_TYPE(0);
}
}
}
// skip the block if the mask is entirely -inf
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
barrier();
if (gl_SubgroupInvocationID == 0) {
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
}
barrier();
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
max_mask = max(max_mask, tmpsh[s]);
}
if (max_mask <= NEG_FLT_MAX_OVER_2) {
continue;
}
// skip the block if the mask is entirely -inf
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
barrier();
if (gl_SubgroupInvocationID == 0) {
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
}
barrier();
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
max_mask = max(max_mask, tmpsh[s]);
}
if (max_mask <= NEG_FLT_MAX_OVER_2) {
continue;
}
}
}