fix shmem support function

Ruben Ortlam 2026-02-12 09:02:30 +01:00
parent 3ed9183ac9
commit 28a3c0b859
1 changed file with 41 additions and 36 deletions


@@ -2803,6 +2803,26 @@ static uint32_t fa_subgroup_size(const vk_device& device, FaCodePath path) {
     return device->subgroup_size;
 }
 
+static uint32_t fa_workgroup_size(const vk_device& device, FaCodePath path, uint32_t hsk, uint32_t hsv, FaRows rows, uint32_t Br, uint32_t Bc) {
+    const uint32_t D = std::max(hsk, hsv);
+    switch (path) {
+    case FA_COOPMAT2:
+        return ((rows != FA_ROWS_LARGE && (D % 32) == 0) ? 256 : 128);
+    case FA_COOPMAT1:
+        return (Bc / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc
+    case FA_VECTOR:
+        return device->vendor_id == VK_VENDOR_ID_AMD ? 256 : 128;
+    default:
+        if (device->vendor_id == VK_VENDOR_ID_INTEL) {
+            return 128;
+        } else if (device->subgroup_size > 32 && Br < 4) {
+            return device->subgroup_size * 2;
+        } else {
+            return device->subgroup_size * 4;
+        }
+    }
+}
+
 static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache) {
     GGML_UNUSED(clamp);
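
The helper above consolidates the workgroup-size selection that was previously inlined in ggml_vk_load_shaders (next hunk). For a feel for what the default (scalar) branch yields, a minimal standalone sketch follows; the vendor/subgroup-size combinations are illustrative assumptions, not values taken from this commit:

    // Minimal sketch of fa_workgroup_size()'s default (scalar) branch.
    // The disable_subgroups handling at the old call site is not reproduced here.
    #include <cstdint>
    #include <cstdio>

    static uint32_t scalar_wg_size(bool is_intel, uint32_t subgroup_size, uint32_t Br) {
        if (is_intel) {
            return 128;               // Intel: fixed workgroup size
        } else if (subgroup_size > 32 && Br < 4) {
            return subgroup_size * 2; // wide subgroups with few rows: smaller workgroup
        } else {
            return subgroup_size * 4;
        }
    }

    int main() {
        printf("Intel  sg=32 Br=4 -> %u\n", scalar_wg_size(true,  32, 4)); // 128
        printf("wave64 sg=64 Br=1 -> %u\n", scalar_wg_size(false, 64, 1)); // 128
        printf("sg=32        Br=4 -> %u\n", scalar_wg_size(false, 32, 4)); // 128
        printf("wave64 sg=64 Br=4 -> %u\n", scalar_wg_size(false, 64, 4)); // 256
        return 0;
    }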
@@ -3220,29 +3240,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     const uint32_t D = (hsk|hsv);
     auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, rows, small_cache);
 
-    uint32_t wg_size;
-    switch (path) {
-    case FA_COOPMAT2:
-        wg_size = ((rows != FA_ROWS_LARGE && (D % 32) == 0) ? 256 : 128);
-        break;
-    case FA_COOPMAT1:
-        if (disable_subgroups) {
-            wg_size = 128;
-        } else {
-            wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc
-        }
-        break;
-    default:
-        if (disable_subgroups) {
-            wg_size = 128;
-        } else if (device->subgroup_size > 32 && rows_cols[0] < 4) {
-            wg_size = device->subgroup_size * 2;
-        } else {
-            wg_size = device->subgroup_size * 4;
-        }
-        break;
-    }
+    const uint32_t wg_size = fa_workgroup_size(device, path, hsk, hsv, rows, rows_cols[0], rows_cols[1]);
 
     const uint32_t subgroup_size = fa_subgroup_size(device, path);
 
     // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
@@ -8426,21 +8424,29 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
     }
 }
 
-static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, FaRows rows, bool small_cache) {
+static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache, bool fp32acc) {
     // Needs to be kept up to date on shader changes
-    GGML_UNUSED(hsv);
-    const uint32_t wg_size = scalar_flash_attention_workgroup_size;
-    const uint32_t Br = get_fa_scalar_num_rows(hsk, hsv, rows, small_cache);
-    const uint32_t Bc = scalar_flash_attention_Bc;
+    const std::array<uint32_t, 2> rows_cols = fa_rows_cols(FA_SCALAR, hsk, hsv, clamp, type, rows, small_cache);
+    const uint32_t Br = rows_cols[0];
+    const uint32_t Bc = rows_cols[1];
+    const uint32_t wg_size = fa_workgroup_size(device, FA_SCALAR, hsk, hsv, rows, Br, Bc);
+
+    const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
+    const uint32_t acc_type_size = !fp32acc ? sizeof(ggml_fp16_t) : sizeof(float);
 
     // tmpsh is overestimated slightly
     const uint32_t tmpsh = wg_size * sizeof(float);
-    const uint32_t tmpshv4 = wg_size * 4 * sizeof(float);
+    const uint32_t tmpshv4 = wg_size * 4 * acc_type_size;
 
-    const uint32_t masksh = Bc * Br * sizeof(float);
+    const uint32_t masksh = Bc * (Br + 1) * float_type_size;
 
-    const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float);
+    const uint32_t Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
 
-    const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
+    const uint32_t D = std::max(hsk, hsv);
+    const bool shmem_staging = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256;
+    const uint32_t kvsh = shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
+
+    const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh;
     const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
 
     VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
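
To get a feel for the budget the rewritten check computes, here is a standalone sketch of the new formula; the Br/Bc/wg_size values are assumed for illustration, since the tile sizes fa_rows_cols returns are not shown in this diff:

    // Standalone sketch of the updated shmem estimate (illustrative only).
    // Br = 32, Bc = 32, wg_size = 128 are assumed tile/workgroup values;
    // the real ones come from fa_rows_cols() and fa_workgroup_size().
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t hsk = 128, hsv = 128;
        const uint32_t Br = 32, Bc = 32, wg_size = 128;         // assumed
        const bool fp16 = true, fp32acc = false, nvidia = true; // assumed device

        const uint32_t float_type_size = fp16 ? 2 : 4; // ggml_fp16_t vs float
        const uint32_t acc_type_size   = fp32acc ? 4 : 2;

        const uint32_t tmpsh   = wg_size * sizeof(float);                  // reduction scratch
        const uint32_t tmpshv4 = wg_size * 4 * acc_type_size;              // vec4 accumulators
        const uint32_t masksh  = Bc * (Br + 1) * float_type_size;          // padded mask tile
        const uint32_t Qf      = Br * (hsk / 4 + 1) * 4 * float_type_size; // padded Q tile

        const uint32_t D = std::max(hsk, hsv);
        const bool shmem_staging = nvidia && hsk < 256 && hsv < 256;
        const uint32_t kvsh = shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size
                                            : 4 * float_type_size;         // K/V staging tile

        const uint32_t total = tmpsh + tmpshv4 + masksh + Qf + kvsh;
        printf("estimated shmem: %u bytes\n", total); // 20544 for these inputs
        return 0;
    }

For these inputs the estimate is 20544 bytes, comfortably inside a typical 48 KiB maxComputeSharedMemorySize; with fp32 data and fp32 accumulators the same tiles cost roughly twice as much (40576 bytes), which is why the check now takes fp32acc and consults device->fp16 instead of assuming sizeof(float) everywhere.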
@@ -8569,6 +8575,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     uint32_t workgroups_z = (uint32_t)neq3;
 
     const bool small_cache = nek1 < 1024;
+    const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32;
 
     // For scalar/coopmat1 FA, we can use the "large" size to accommodate gqa.
     // For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
@@ -8617,8 +8624,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     }
 
     // with large hsk/hsv, scalar path may need to use small rows to fit in shared memory
-    if (path == FA_SCALAR && rows == FA_ROWS_LARGE && !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, FA_ROWS_LARGE, small_cache)) {
-        rows = FA_ROWS_SMALL;
+    if (path == FA_SCALAR && rows == FA_ROWS_LARGE && !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, 0, k->type, FA_ROWS_LARGE, small_cache, f32acc)) {
+        rows = FA_ROWS_8;
     }
 
     const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
@@ -8643,8 +8650,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         aligned = false;
     }
 
-    bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
-
     float scale = 1.0f;
     float max_bias = 0.0f;
     float logit_softcap = 0.0f;