vulkan: fix OOB check in flash_attn_mask_opt (#20296)

This commit is contained in:
Jeff Bolz 2026-03-12 00:35:49 -05:00 committed by GitHub
parent 5866e3bbc8
commit aa429cf507
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 60 additions and 40 deletions

View File

@ -8840,7 +8840,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
} }
// Only use mask opt when the mask is fairly large. This hasn't been tuned extensively. // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768; bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768 && nem0 >= tuning_params.block_cols * 16;
vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc, vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc,
mask != nullptr, use_mask_opt, logit_softcap != 0); mask != nullptr, use_mask_opt, logit_softcap != 0);

View File

@ -33,28 +33,11 @@ layout (push_constant) uniform parameter {
shared float minsh[NUM_SUBGROUPS]; shared float minsh[NUM_SUBGROUPS];
shared float maxsh[NUM_SUBGROUPS]; shared float maxsh[NUM_SUBGROUPS];
// For each Br x Bc block of the mask (input) buffer, read all values and check
// if it's all -inf or all zero. Write out a two-bit code indicating which it is
// (or zero for neither). Each workgroup processes 16 tiles and writes out a
// 32-bit result mask.
//
// TODO: This is a lot of work per workgroup, might make sense to split this into
// more workgroups in the future.
void main() {
// Each workgroup handles a row
const uint tid = gl_LocalInvocationIndex;
const uint i0 = gl_WorkGroupID.x;
const uint i1 = gl_WorkGroupID.y;
const uint i2 = gl_WorkGroupID.z % nem2;
const uint i3 = gl_WorkGroupID.z / nem2;
float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF); float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF);
uint result = 0; void loadvec4(inout uint result, const uint i0, const uint i1, const uint i2, const uint i3, const bool need_bounds_check) {
const uint tid = gl_LocalInvocationIndex;
// Fast path for fully in-bounds blocks where we can do f16vec4 loads
if ((nem0 % Bc) == 0 && (nem1 % Br) == 0 &&
((Br * Bc) % (BLOCK_SIZE * 4)) == 0) {
[[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) { [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
float min_v = FLT_MAX_OVER_2; float min_v = FLT_MAX_OVER_2;
float max_v = -FLT_MAX_OVER_2; float max_v = -FLT_MAX_OVER_2;
@ -66,11 +49,21 @@ void main() {
j0 += (i0 * 16 + block_x) * Bc; j0 += (i0 * 16 + block_x) * Bc;
j1 += i1 * Br; j1 += i1 * Br;
if (!need_bounds_check || j0 + 3 < nem0) {
vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3) / 4]); vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3) / 4]);
[[unroll]] for (int c = 0; c < 4; ++c) { [[unroll]] for (int c = 0; c < 4; ++c) {
min_v = min(min_v, f[c]); min_v = min(min_v, f[c]);
max_v = max(max_v, f[c]); max_v = max(max_v, f[c]);
} }
} else {
[[unroll]] for (int c = 0; c < 4; ++c) {
if (j0 + c < nem0) {
float f = float(data_a[j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3]);
min_v = min(min_v, f);
max_v = max(max_v, f);
}
}
}
} }
min_v = subgroupMin(min_v); min_v = subgroupMin(min_v);
max_v = subgroupMax(max_v); max_v = subgroupMax(max_v);
@ -93,6 +86,33 @@ void main() {
} }
barrier(); barrier();
} }
}
// For each Br x Bc block of the mask (input) buffer, read all values and check
// if it's all -inf or all zero. Write out a two-bit code indicating which it is
// (or zero for neither). Each workgroup processes 16 tiles and writes out a
// 32-bit result mask.
//
// TODO: This is a lot of work per workgroup, might make sense to split this into
// more workgroups in the future.
void main() {
// Each workgroup handles a row
const uint tid = gl_LocalInvocationIndex;
const uint i0 = gl_WorkGroupID.x;
const uint i1 = gl_WorkGroupID.y;
const uint i2 = gl_WorkGroupID.z % nem2;
const uint i3 = gl_WorkGroupID.z / nem2;
uint result = 0;
// Fast path for fully in-bounds blocks where we can do f16vec4 loads
if ((nem0 % Bc) == 0 && (nem1 % Br) == 0 &&
((Br * Bc) % (BLOCK_SIZE * 4)) == 0) {
if ((i0 + 1) * 16 * Bc <= nem0) {
loadvec4(result, i0, i1, i2, i3, false);
} else {
loadvec4(result, i0, i1, i2, i3, true);
}
} else { } else {
[[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) { [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
float min_v = FLT_MAX_OVER_2; float min_v = FLT_MAX_OVER_2;