vulkan: make FA mask/softcap enables spec constants (#19309)

* vulkan: make FA mask/softcap enables spec constants

* don't specialize for sinks

* bump timeout a little bit
Jeff Bolz, 2026-02-06 01:49:58 -06:00, committed by GitHub
parent 7fcf1ef45d
commit f9bd518a6b
6 changed files with 45 additions and 38 deletions
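
For context: the change replaces the single use_mask_opt boolean in the flash-attention pipeline key with a small bitfield, so the mask-enable and logit-softcap switches also become specialization constants instead of runtime push-constant checks. A minimal standalone C++ sketch of the encoding (the enumerator names are illustrative; the diff below uses the raw literals 1, 2, and 4):

#include <cstdint>

// Bit layout of the new Flags spec constant (values match the diff).
enum vk_fa_flag_bits : uint32_t {
    FA_FLAG_USE_MASK_OPT  = 1u << 0, // mask-optimization metadata in use
    FA_FLAG_MASK_ENABLE   = 1u << 1, // a mask tensor is present
    FA_FLAG_LOGIT_SOFTCAP = 1u << 2, // logit_softcap != 0
};

static uint32_t fa_make_flags(bool use_mask_opt, bool has_mask, bool has_softcap) {
    return (use_mask_opt ? FA_FLAG_USE_MASK_OPT  : 0u) |
           (has_mask     ? FA_FLAG_MASK_ENABLE   : 0u) |
           (has_softcap  ? FA_FLAG_LOGIT_SOFTCAP : 0u);
}

Per the second commit-message bullet, the sink enable is deliberately left out of these flags and stays a runtime bit, so enabling sinks does not force compilation of another pipeline variant.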


@@ -468,7 +468,7 @@ jobs:
 export GGML_VK_VISIBLE_DEVICES=0
 export GGML_VK_DISABLE_F16=1
 # This is using llvmpipe and runs slower than other backends
-ctest -L main --verbose --timeout 4200
+ctest -L main --verbose --timeout 4800
 ubuntu-24-cmake-webgpu:
 runs-on: ubuntu-24.04


@@ -402,19 +402,19 @@ enum FaCodePath {
 };

 struct vk_fa_pipeline_state {
-    vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, bool use_mask_opt)
-        : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), use_mask_opt(use_mask_opt) {}
+    vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, uint32_t flags)
+        : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), flags(flags) {}

     uint32_t HSK, HSV;
     bool small_rows, small_cache;
     FaCodePath path;
     bool aligned;
     bool f32acc;
-    bool use_mask_opt;
+    uint32_t flags;

     bool operator<(const vk_fa_pipeline_state &b) const {
-        return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, use_mask_opt) <
-               std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.use_mask_opt);
+        return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags) <
+               std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.flags);
     }
 };
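
Because each distinct Flags value now selects a differently specialized shader, the value has to participate in the pipeline-cache key above, which is why flags joins the std::tie comparison. A reduced sketch of the pattern (the struct and map names here are illustrative, not the actual members of ggml-vulkan.cpp):

#include <cstdint>
#include <map>
#include <tuple>

struct fa_key {
    uint32_t HSK, HSV, flags;
    // std::tie builds a tuple of references; tuple's operator< compares
    // element-wise lexicographically, which is exactly the strict weak
    // ordering a std::map key requires.
    bool operator<(const fa_key &b) const {
        return std::tie(HSK, HSV, flags) < std::tie(b.HSK, b.HSV, b.flags);
    }
};

std::map<fa_key, int> fa_pipelines; // value type is illustrative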
@@ -3193,7 +3193,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1};
     };

-    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, bool use_mask_opt) -> std::vector<uint32_t> {
+    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, uint32_t flags) -> std::vector<uint32_t> {
         // For large number of rows, 128 invocations seems to work best.
         // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
         // can't use 256 for D==80.
@@ -3225,7 +3225,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         // AMD prefers loading K directly from global memory
         const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0;

-        return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, use_mask_opt};
+        return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, flags};
     };

 #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
@@ -3237,19 +3237,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
         FaCodePath path = fa.first.path; \
         bool aligned = fa.first.aligned; \
         bool f32acc = fa.first.f32acc; \
-        bool use_mask_opt = fa.first.use_mask_opt; \
+        uint32_t flags = fa.first.flags; \
         if (path == FAPATH) { \
             if (aligned) { \
                 if (f32acc) { \
-                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,use_mask_opt), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
+                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
                 } else { \
-                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,use_mask_opt), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
+                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
                 } \
             } else { \
                 if (f32acc) { \
-                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,use_mask_opt), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
+                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
                 } else { \
-                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,use_mask_opt), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
+                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
                 } \
             } \
         } \
@@ -8595,10 +8595,26 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;

+    float scale = 1.0f;
+    float max_bias = 0.0f;
+    float logit_softcap = 0.0f;
+
+    memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0) {
+        scale /= logit_softcap;
+    }
+
     // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
     bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768;

-    vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, use_mask_opt);
+    uint32_t flags = (use_mask_opt ? 1 : 0) |
+                     (mask != nullptr ? 2 : 0) |
+                     (logit_softcap != 0 ? 4 : 0);
+
+    vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags);

     vk_pipeline pipeline = nullptr;
@@ -8678,18 +8694,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         }
     }

-    float scale = 1.0f;
-    float max_bias = 0.0f;
-    float logit_softcap = 0.0f;
-
-    memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
-    memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
-
-    if (logit_softcap != 0) {
-        scale /= logit_softcap;
-    }
-
     const uint32_t n_head_kv = neq2;
     const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
     const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@@ -8703,7 +8707,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf;
     vk_subbuffer mask_opt_buf = use_mask_opt ? ggml_vk_subbuffer(ctx, ctx->prealloc_y, 0) : q_buf;

-    uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2;
+    uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | n_head_log2;

     if (use_mask_opt)
     {
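
With the mask enable baked in at specialization time, the (1<<16) runtime bit disappears from mask_n_head_log2 and only the sink enable and n_head_log2 remain dynamic. A hedged before/after sketch of that packing, using the names from the diff:

#include <cstdint>

// Before this change: sink enable, mask enable, and n_head_log2 shared one word.
static uint32_t pack_mask_n_head_log2_old(bool has_sinks, bool has_mask, uint32_t n_head_log2) {
    return (uint32_t(has_sinks) << 24) | (uint32_t(has_mask) << 16) | n_head_log2;
}

// After: the mask bit moved into the Flags spec constant, so it is no
// longer re-packed into the push constant on every dispatch.
static uint32_t pack_mask_n_head_log2_new(bool has_sinks, uint32_t n_head_log2) {
    return (uint32_t(has_sinks) << 24) | n_head_log2;
}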


@@ -127,7 +127,7 @@ void main() {
             continue;
         }

         // Only load if the block is not all zeros
-        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0 && mask_opt_bits != MASK_OPT_ALL_ZERO) {
+        if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
             bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;

             [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
@@ -181,7 +181,7 @@ void main() {
             }
         }

-        if (p.logit_softcap != 0.0f) {
+        if (LOGIT_SOFTCAP) {
             [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                 [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                     Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]);
@@ -189,7 +189,7 @@ void main() {
             }
         }

-        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0 && mask_opt_bits != MASK_OPT_ALL_ZERO) {
+        if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
             [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                 [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                     float mvf = masksh[c * cols_per_iter + col_tid][r];


@@ -10,7 +10,11 @@ layout (constant_id = 5) const uint32_t Clamp = 0;
 layout (constant_id = 6) const uint32_t D_split = 16;
 layout (constant_id = 7) const uint32_t SubGroupSize = 32;
 layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0;
-layout (constant_id = 9) const bool USE_MASK_OPT = false;
+layout (constant_id = 9) const uint32_t Flags = 0;
+
+const bool USE_MASK_OPT = (Flags & 1) != 0;
+const bool MASK_ENABLE = (Flags & 2) != 0;
+const bool LOGIT_SOFTCAP = (Flags & 4) != 0;

 // Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
 const uint32_t HSK_pad = (HSK + 15) & ~15;
@@ -60,7 +64,6 @@ layout (push_constant) uniform parameter {
 } p;

 #define SINK_ENABLE_BIT (1<<24)
-#define MASK_ENABLE_BIT (1<<16)
 #define N_LOG2_MASK 0xFFFF

 layout (binding = 4) readonly buffer S {float data_s[];};
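
The payoff of deriving USE_MASK_OPT, MASK_ENABLE, and LOGIT_SOFTCAP from a specialization constant is that the driver can constant-fold the branches guarded by them when it compiles the pipeline, so a disabled feature costs nothing per invocation. A rough C++ analogue of that folding, with if constexpr standing in for the driver's spec-constant folding (illustrative only):

#include <cmath>
#include <cstdint>

template <uint32_t Flags>
float apply_softcap(float s, float logit_softcap) {
    if constexpr ((Flags & 4) != 0) {      // the LOGIT_SOFTCAP bit, as in the shader
        return logit_softcap * std::tanh(s);
    } else {
        return s;                          // branch removed at compile time
    }
}

The trade-off, as with any specialization, is more pipeline variants to compile and cache, which is why the flags value is part of vk_fa_pipeline_state above.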


@@ -160,7 +160,7 @@ void main() {
                 mask_cache[idx] = f16vec4(0);
             }

-            if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+            if (MASK_ENABLE) {
                 if (USE_MASK_OPT && mask_opt_idx != j / 16) {
                     mask_opt_idx = j / 16;
@@ -303,7 +303,7 @@ void main() {
         coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor);
         barrier();

-        if (p.logit_softcap != 0.0f) {
+        if (LOGIT_SOFTCAP) {
             [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
                 uint32_t c = (idx + tid) / (Br / 4);
                 uint32_t r = (idx + tid) % (Br / 4);
@@ -314,7 +314,7 @@ void main() {
             barrier();
         }

-        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+        if (MASK_ENABLE) {
             [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
                 uint32_t c = (idx + tid) / (Br / 4);
                 uint32_t r = (idx + tid) % (Br / 4);


@@ -155,7 +155,7 @@ void main() {
         for (uint32_t j = start_j; j < end_j; ++j) {
             coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);

-            if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+            if (MASK_ENABLE) {
                 if (USE_MASK_OPT && mask_opt_idx != j / 16) {
                     mask_opt_idx = j / 16;
@@ -197,14 +197,14 @@ void main() {
             coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
             S = coopMatMulAdd(Qf16, K_T, S);

-            if (p.logit_softcap != 0.0f) {
+            if (LOGIT_SOFTCAP) {
                 [[unroll]]
                 for (int k = 0; k < S.length(); ++k) {
                     S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
                 }
             }

-            if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+            if (MASK_ENABLE) {
                 S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
             }