vulkan: in flash attention, bounds check against nem1 (don't rely on GGML_KQ_MASK_PAD) (#16316)

Jeff Bolz, 2025-10-03 03:33:08 -05:00, committed by GitHub
commit e308efda8e (parent 136bda78c5)
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
4 changed files with 27 additions and 12 deletions

ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -2614,8 +2614,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
         const uint32_t D_lsb = D ^ (D & (D-1));
         uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
-        // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
-        GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
         return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
     };
@@ -7457,8 +7455,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     if (((HSK | HSV) % 16) != 0 && path == FA_COOPMAT2) {
         aligned = false;
     }
-    // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
-    GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0);
     bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
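
The two removed GGML_ASSERTs encoded the old invariant: the mask's dim1 (nem1) was assumed to be padded to GGML_KQ_MASK_PAD, so row-tile loads from the mask never needed clamping. A minimal standalone C++ sketch of that invariant and why an unpadded mask breaks it (hypothetical Br/nem1 values; GGML_KQ_MASK_PAD is 64 per the removed comment; this snippet is not part of the commit):

    #include <cstdint>
    #include <cstdio>

    // Sketch only: Br and nem1 are made-up example values.
    constexpr uint32_t GGML_KQ_MASK_PAD = 64;

    int main() {
        const uint32_t Br   = 32;  // shader row tile size (rows_cols[0])
        const uint32_t nem1 = 40;  // mask dim1, here NOT padded to 64

        // The removed asserts required both conditions, guaranteeing that
        // every Br-row tile stayed inside the (padded) mask:
        printf("GGML_KQ_MASK_PAD %% Br == 0:   %d\n", (GGML_KQ_MASK_PAD % Br) == 0);
        printf("nem1 %% GGML_KQ_MASK_PAD == 0: %d\n", (nem1 % GGML_KQ_MASK_PAD) == 0);

        // After this commit the shaders clamp instead: row r of tile i is
        // read only when i * Br + r < nem1, so an unpadded mask is safe.
        return 0;
    }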

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

@@ -153,12 +153,13 @@ void main() {
     }
     if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+        bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
         [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
             uint32_t c = (idx + tid) % Bc;
             uint32_t r = (idx + tid) / Bc;
             if (idx + tid < Bc * Br) {
-                if (!KV_bounds_check || j * Bc + c < KV) {
+                if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
                     masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
                 } else {
                     masksh[c][r] = float(0);
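
The new nem1_bounds_check predicate keeps the clamp off the common path: it is true only when GQA is not in use and nem1 is not a multiple of Br, which is the only case where a Br-row tile can read past the end of the mask. A C++ mirror of the same logic (standalone sketch with made-up inputs, not part of the shader):

    #include <cstdint>
    #include <cstdio>

    // Hypothetical host-side mirror of the shader predicate, to show when
    // the extra per-row clamp is actually paid for.
    static bool nem1_bounds_check(uint32_t gqa_ratio, uint32_t nem1, uint32_t Br) {
        // Under GQA the nem1 clamp is intentionally skipped (see the comment
        // in the coopmat2 hunk below); and when nem1 is a multiple of Br,
        // whole tiles can never overrun the mask, so the check is elided.
        return !(gqa_ratio > 1) && (nem1 % Br) != 0;
    }

    int main() {
        printf("%d\n", nem1_bounds_check(1, 40, 32));  // 1: clamp rows >= 40
        printf("%d\n", nem1_bounds_check(1, 64, 32));  // 0: tiles fit exactly
        printf("%d\n", nem1_bounds_check(4, 40, 32));  // 0: GQA, no nem1 clamp
        return 0;
    }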

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp

@@ -201,11 +201,13 @@ void main() {
     }
     if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+        bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
         [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
             uint32_t c = (idx + tid) % Bc;
             uint32_t r = (idx + tid) / Bc;
             if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
-                if (!KV_bounds_check || j * Bc + c < KV) {
+                if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
                     sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
                 }
             }

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

@@ -154,15 +154,31 @@ void main() {
     }
     if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
-        tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
-        tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
-        tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
-        coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
-        coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
-        S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
+        bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
+        if (nem1_bounds_check) {
+            tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+            tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
+            tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
+            coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
+            coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+            S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
+        } else {
+            tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
+            // Don't clamp against nem1 when GQA is enabled
+            uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
+            tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
+            tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
+            coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
+            coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+            S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
+        }
     }
     // Clear padding elements to -inf, so they don't contribute to rowmax
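
Unlike the scalar and coopmat1 shaders, the coopmat2 path cannot branch per element inside coopMatLoadTensorNV, so the fix picks the tensor-layout clamp mode up front: gl_CooperativeMatrixClampModeConstantNV when rows must be clamped against nem1 (out-of-bounds elements load as a constant instead of reading memory), and the original Clamp mode otherwise, with the layout height forced to ~0 under GQA so nem1 never clips the load. A small standalone C++ sketch of that selection (hypothetical values; not part of the commit):

    #include <cstdint>
    #include <cstdio>

    // Hypothetical host-side mirror of the coopmat2 path selection above.
    int main() {
        const uint32_t Br        = 32;
        const uint32_t nem1      = 40;  // mask rows, not a multiple of Br
        const uint32_t gqa_ratio = 1;

        bool nem1_bounds_check = !(gqa_ratio > 1) && (nem1 % Br) != 0;
        if (nem1_bounds_check) {
            // Layout height p.nem1 + constant clamp mode: loads from rows
            // >= nem1 yield the clamp constant rather than touching memory.
            printf("constant-clamp layout, height = %u\n", nem1);
        } else {
            // Original clamp mode; under GQA the height is ~0 so nem1 never
            // clips the load (mirrors: p.gqa_ratio > 1 ? ~0 : p.nem1).
            uint32_t m_height = gqa_ratio > 1 ? ~0u : nem1;
            printf("wrap/clamp layout, height = %u\n", m_height);
        }
        return 0;
    }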