fix shmem support function
This commit is contained in:
parent
3ed9183ac9
commit
28a3c0b859
@@ -2803,6 +2803,26 @@ static uint32_t fa_subgroup_size(const vk_device& device, FaCodePath path) {
     return device->subgroup_size;
 }
 
+static uint32_t fa_workgroup_size(const vk_device& device, FaCodePath path, uint32_t hsk, uint32_t hsv, FaRows rows, uint32_t Br, uint32_t Bc) {
+    const uint32_t D = std::max(hsk, hsv);
+    switch (path) {
+    case FA_COOPMAT2:
+        return ((rows != FA_ROWS_LARGE && (D % 32) == 0) ? 256 : 128);
+    case FA_COOPMAT1:
+        return (Bc / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc
+    case FA_VECTOR:
+        return device->vendor_id == VK_VENDOR_ID_AMD ? 256 : 128;
+    default:
+        if (device->vendor_id == VK_VENDOR_ID_INTEL) {
+            return 128;
+        } else if (device->subgroup_size > 32 && Br < 4) {
+            return device->subgroup_size * 2;
+        } else {
+            return device->subgroup_size * 4;
+        }
+    }
+}
+
 static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache) {
     GGML_UNUSED(clamp);
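For intuition about the helper added above, here is a minimal standalone sketch (not part of the commit) of its default-path selection, with the vk_device fields reduced to plain parameters; the device values in main are hypothetical examples:

#include <cstdint>
#include <cstdio>

// Sketch of the default (scalar) branch of fa_workgroup_size: Intel gets a
// fixed 128-invocation workgroup; wide subgroups with few rows get 2 subgroups
// per workgroup; everything else gets 4.
static uint32_t default_path_wg_size(uint32_t subgroup_size, bool is_intel, uint32_t Br) {
    if (is_intel) {
        return 128;
    } else if (subgroup_size > 32 && Br < 4) {
        return subgroup_size * 2;
    } else {
        return subgroup_size * 4;
    }
}

int main() {
    // Hypothetical wave64 device with Br = 1: 64 * 2 = 128 invocations.
    std::printf("%u\n", default_path_wg_size(64, false, 1));
    // Hypothetical subgroup-size-32 device with Br = 8: 32 * 4 = 128.
    std::printf("%u\n", default_path_wg_size(32, false, 8));
    return 0;
}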
@@ -3220,29 +3240,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         const uint32_t D = (hsk|hsv);
         auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, rows, small_cache);
 
-        uint32_t wg_size;
-        switch (path) {
-        case FA_COOPMAT2:
-            wg_size = ((rows != FA_ROWS_LARGE && (D % 32) == 0) ? 256 : 128);
-            break;
-        case FA_COOPMAT1:
-            if (disable_subgroups) {
-                wg_size = 128;
-            } else {
-                wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc
-            }
-            break;
-        default:
-            if (disable_subgroups) {
-                wg_size = 128;
-            } else if (device->subgroup_size > 32 && rows_cols[0] < 4) {
-                wg_size = device->subgroup_size * 2;
-            } else {
-                wg_size = device->subgroup_size * 4;
-            }
-            break;
-        }
-
+        const uint32_t wg_size = fa_workgroup_size(device, path, hsk, hsv, rows, rows_cols[0], rows_cols[1]);
         const uint32_t subgroup_size = fa_subgroup_size(device, path);
 
         // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
@@ -8426,21 +8424,29 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
     }
 }
 
-static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, FaRows rows, bool small_cache) {
+static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache, bool fp32acc) {
     // Needs to be kept up to date on shader changes
     GGML_UNUSED(hsv);
-    const uint32_t wg_size = scalar_flash_attention_workgroup_size;
-    const uint32_t Br = get_fa_scalar_num_rows(hsk, hsv, rows, small_cache);
-    const uint32_t Bc = scalar_flash_attention_Bc;
+    const std::array<uint32_t, 2> rows_cols = fa_rows_cols(device, FA_SCALAR, hsk, hsv, clamp, type, rows, small_cache);
+    const uint32_t Br = rows_cols[0];
+    const uint32_t Bc = rows_cols[1];
+    const uint32_t wg_size = fa_workgroup_size(device, FA_SCALAR, hsk, hsv, rows, Br, Bc);
+
+    const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
+    const uint32_t acc_type_size = !fp32acc ? sizeof(ggml_fp16_t) : sizeof(float);
 
     // tmpsh is overestimated slightly
     const uint32_t tmpsh = wg_size * sizeof(float);
-    const uint32_t tmpshv4 = wg_size * 4 * sizeof(float);
+    const uint32_t tmpshv4 = wg_size * 4 * acc_type_size;
 
-    const uint32_t masksh = Bc * Br * sizeof(float);
+    const uint32_t masksh = Bc * (Br + 1) * float_type_size;
 
-    const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float);
+    const uint32_t Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
 
-    const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
+    const uint32_t D = std::max(hsk, hsv);
+    const bool shmem_staging = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256;
+    const uint32_t kvsh = shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
+
+    const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh;
     const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
 
     VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
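To see what the reworked budget adds up to, here is a self-contained sketch (not part of the commit) that recomputes total_size using the formulas above for one hypothetical configuration: an fp16-capable device with fp16 accumulation, HSK = HSV = 128, Br = 32, Bc = 64, wg_size = 128, and the NVIDIA staging path enabled. All inputs are illustrative:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t hsk = 128, hsv = 128, Br = 32, Bc = 64, wg_size = 128;
    const uint32_t float_type_size = 2; // fp16 device
    const uint32_t acc_type_size   = 2; // fp16 accumulation (fp32acc == false)

    const uint32_t tmpsh   = wg_size * 4;                              // sizeof(float) -> 512
    const uint32_t tmpshv4 = wg_size * 4 * acc_type_size;              // 1024
    const uint32_t masksh  = Bc * (Br + 1) * float_type_size;          // 4224
    const uint32_t Qf      = Br * (hsk / 4 + 1) * 4 * float_type_size; // 8448
    const uint32_t D       = std::max(hsk, hsv);
    const bool shmem_staging = true; // NVIDIA with hsk, hsv < 256
    const uint32_t kvsh = shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size
                                        : 4 * float_type_size;         // 16896

    // 512 + 1024 + 4224 + 8448 + 16896 = 31104 bytes, which fits a typical
    // 32 KiB (or larger) maxComputeSharedMemorySize.
    const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh;
    std::printf("total_size = %u bytes\n", total_size);
    return 0;
}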
@@ -8569,6 +8575,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     uint32_t workgroups_z = (uint32_t)neq3;
 
     const bool small_cache = nek1 < 1024;
+    const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32;
 
     // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga.
     // For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
@@ -8617,8 +8624,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     }
 
     // with large hsk/hsv, scalar path may need to use small rows to fit in shared memory
-    if (path == FA_SCALAR && rows == FA_ROWS_LARGE && !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, FA_ROWS_LARGE, small_cache)) {
-        rows = FA_ROWS_SMALL;
+    if (path == FA_SCALAR && rows == FA_ROWS_LARGE && !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, 0, k->type, FA_ROWS_LARGE, small_cache, f32acc)) {
+        rows = FA_ROWS_8;
     }
 
     const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
@@ -8643,8 +8650,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         aligned = false;
     }
 
-    bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
-
     float scale = 1.0f;
     float max_bias = 0.0f;
     float logit_softcap = 0.0f;