device tuning
This commit is contained in:
parent
dd92b1f8d5
commit
0b4b0d2e57
|
|
@ -2762,12 +2762,16 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
|
|||
// number of rows/cols for flash attention shader
|
||||
static constexpr uint32_t flash_attention_num_small_rows = 32;
|
||||
|
||||
static uint32_t get_fa_scalar_num_rows(uint32_t hsk, uint32_t hsv, FaRows rows, bool small_cache) {
|
||||
static uint32_t get_fa_scalar_num_rows(const vk_device& device, uint32_t hsk, uint32_t hsv, FaRows rows, bool small_cache) {
|
||||
if (rows == FA_ROWS_1) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (rows == FA_ROWS_SMALL || hsv >= 192 || (hsv | hsk) & 8 || small_cache) {
|
||||
if (
|
||||
rows == FA_ROWS_SMALL || hsv >= 192 || (hsv | hsk) & 8 || small_cache ||
|
||||
(device->architecture == AMD_GCN && hsk <= 64) ||
|
||||
(device->vendor_id == VK_VENDOR_ID_INTEL)
|
||||
) {
|
||||
return 8;
|
||||
}
|
||||
|
||||
|
|
@ -2785,12 +2789,12 @@ static bool fa_disable_subgroups(const vk_device& device, FaCodePath path) {
|
|||
return device->vendor_id == VK_VENDOR_ID_INTEL && path == FA_SCALAR;
|
||||
}
|
||||
|
||||
static uint32_t fa_subgroup_size(const vk_device& device, FaCodePath path) {
|
||||
static uint32_t fa_subgroup_size(const vk_device& device, FaCodePath path, FaRows rows) {
|
||||
if (fa_disable_subgroups(device, path)) {
|
||||
return 0xFFFFFFFF;
|
||||
}
|
||||
|
||||
if (path == FA_SCALAR && device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN) {
|
||||
if (path == FA_SCALAR && device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN && rows == FA_ROWS_1) {
|
||||
return 32;
|
||||
}
|
||||
|
||||
|
|
@ -2799,14 +2803,14 @@ static uint32_t fa_subgroup_size(const vk_device& device, FaCodePath path) {
|
|||
|
||||
static uint32_t fa_workgroup_size(const vk_device& device, FaCodePath path, uint32_t hsk, uint32_t hsv, FaRows rows, uint32_t Br, uint32_t Bc) {
|
||||
const uint32_t D = std::max(hsk, hsv);
|
||||
const uint32_t subgroup_size = fa_disable_subgroups(device, path) ? 32 : fa_subgroup_size(device, path);
|
||||
const uint32_t subgroup_size = fa_disable_subgroups(device, path) ? 32 : fa_subgroup_size(device, path, rows);
|
||||
switch (path) {
|
||||
case FA_COOPMAT2:
|
||||
return ((rows != FA_ROWS_LARGE && (D % 32) == 0) ? 256 : 128);
|
||||
case FA_COOPMAT1:
|
||||
return (Bc / 16) * subgroup_size; // enough subgroups for Bc/MatBc
|
||||
default:
|
||||
if (subgroup_size > 32 && Br < 4) {
|
||||
if (subgroup_size > 32 && (Br < 4 || hsk < 64)) {
|
||||
return subgroup_size * 2;
|
||||
} else {
|
||||
return subgroup_size * 4;
|
||||
|
|
@ -2814,7 +2818,7 @@ static uint32_t fa_workgroup_size(const vk_device& device, FaCodePath path, uint
|
|||
}
|
||||
}
|
||||
|
||||
static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache) {
|
||||
static std::array<uint32_t, 2> fa_rows_cols(const vk_device& device, FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache) {
|
||||
GGML_UNUSED(clamp);
|
||||
|
||||
if (path == FA_SCALAR) {
|
||||
|
|
@ -2822,9 +2826,9 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
|
|||
// HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
|
||||
// larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
|
||||
// But this only applies to row_split=1, meaning FA_ROWS_1
|
||||
return {get_fa_scalar_num_rows(hsk, hsv, rows, small_cache), 64};
|
||||
return {get_fa_scalar_num_rows(device, hsk, hsv, rows, small_cache), 64};
|
||||
} else {
|
||||
return {get_fa_scalar_num_rows(hsk, hsv, rows, small_cache), 32};
|
||||
return {get_fa_scalar_num_rows(device, hsk, hsv, rows, small_cache), 32};
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2848,8 +2852,8 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
|
|||
return {64, 64};
|
||||
}
|
||||
|
||||
static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, FaRows rows, bool small_cache) {
|
||||
return fa_rows_cols(path, hsk, hsv, 0, type, rows, small_cache)[1];
|
||||
static uint32_t fa_align(const vk_device& device, FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, FaRows rows, bool small_cache) {
|
||||
return fa_rows_cols(device, path, hsk, hsv, 0, type, rows, small_cache)[1];
|
||||
}
|
||||
|
||||
static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
|
||||
|
|
@ -3217,7 +3221,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
};
|
||||
|
||||
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache) -> std::array<uint32_t, 3> {
|
||||
return {fa_rows_cols(path, hsk, hsv, clamp, type, rows, small_cache)[0], 1, 1};
|
||||
return {fa_rows_cols(device, path, hsk, hsv, clamp, type, rows, small_cache)[0], 1, 1};
|
||||
};
|
||||
|
||||
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache, uint32_t flags) -> std::vector<uint32_t> {
|
||||
|
|
@ -3227,10 +3231,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
// For scalar, use 128 (arbitrary)
|
||||
// The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
|
||||
const uint32_t D = (hsk|hsv);
|
||||
auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, rows, small_cache);
|
||||
auto rows_cols = fa_rows_cols(device, path, hsk, hsv, clamp, type, rows, small_cache);
|
||||
|
||||
const uint32_t wg_size = fa_workgroup_size(device, path, hsk, hsv, rows, rows_cols[0], rows_cols[1]);
|
||||
const uint32_t subgroup_size = fa_subgroup_size(device, path);
|
||||
const uint32_t subgroup_size = fa_subgroup_size(device, path, rows);
|
||||
|
||||
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
|
||||
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
|
||||
|
|
@ -3256,13 +3260,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
bool f32acc = fa.first.f32acc; \
|
||||
uint32_t flags = fa.first.flags; \
|
||||
bool fa_ds = fa_disable_subgroups(device, path); \
|
||||
uint32_t fa_sgs = fa_subgroup_size(device, path); \
|
||||
uint32_t fa_sgs = fa_subgroup_size(device, path, rows); \
|
||||
if (path == FAPATH) { \
|
||||
if (aligned) { \
|
||||
if (f32acc) { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(device, FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(device, FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
|
||||
} \
|
||||
} else { \
|
||||
if (f32acc) { \
|
||||
|
|
@ -8420,7 +8424,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache, bool fp32acc) {
|
||||
// Needs to be kept up to date on shader changes
|
||||
const std::array<uint32_t, 2> rows_cols = fa_rows_cols(FA_SCALAR, hsk, hsv, clamp, type, rows, small_cache);
|
||||
const std::array<uint32_t, 2> rows_cols = fa_rows_cols(device, FA_SCALAR, hsk, hsv, clamp, type, rows, small_cache);
|
||||
const uint32_t Br = rows_cols[0];
|
||||
const uint32_t Bc = rows_cols[1];
|
||||
const uint32_t wg_size = fa_workgroup_size(device, FA_SCALAR, hsk, hsv, rows, Br, Bc);
|
||||
|
|
@ -8451,7 +8455,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
|
|||
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) {
|
||||
// Needs to be kept up to date on shader changes
|
||||
GGML_UNUSED(hsv);
|
||||
const auto rows_cols = fa_rows_cols(FA_COOPMAT1, hsk, hsv, 0, kv_type, FA_ROWS_LARGE, false);
|
||||
const auto rows_cols = fa_rows_cols(device, FA_COOPMAT1, hsk, hsv, 0, kv_type, FA_ROWS_LARGE, false);
|
||||
const uint32_t Br = rows_cols[0];
|
||||
const uint32_t Bc = rows_cols[1];
|
||||
|
||||
|
|
@ -8578,7 +8582,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
case FA_SCALAR:
|
||||
case FA_COOPMAT1:
|
||||
// We may switch from coopmat1 to scalar, so use the scalar limit for both
|
||||
max_gqa = get_fa_scalar_num_rows(HSK, HSV, FA_ROWS_LARGE, small_cache);
|
||||
max_gqa = get_fa_scalar_num_rows(ctx->device, HSK, HSV, FA_ROWS_LARGE, small_cache);
|
||||
break;
|
||||
case FA_COOPMAT2:
|
||||
max_gqa = flash_attention_num_small_rows;
|
||||
|
|
@ -8606,14 +8610,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
rows = FA_ROWS_LARGE;
|
||||
}
|
||||
|
||||
// coopmat1 does not actually support "small rows" (it needs 16 rows).
|
||||
// So use scalar instead.
|
||||
if (rows != FA_ROWS_LARGE && path == FA_COOPMAT1) {
|
||||
path = FA_SCALAR;
|
||||
}
|
||||
|
||||
// scalar is faster than coopmat2 when N==1
|
||||
if (rows == FA_ROWS_1 && path == FA_COOPMAT2) {
|
||||
if (rows == FA_ROWS_1 && (path == FA_COOPMAT1 || path == FA_COOPMAT2)) {
|
||||
path = FA_SCALAR;
|
||||
}
|
||||
|
||||
|
|
@ -8634,7 +8632,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
v_stride /= 4;
|
||||
}
|
||||
|
||||
uint32_t alignment = fa_align(path, HSK, HSV, k->type, rows, small_cache);
|
||||
uint32_t alignment = fa_align(ctx->device, path, HSK, HSV, k->type, rows, small_cache);
|
||||
bool aligned = (KV % alignment) == 0 &&
|
||||
// the "aligned" shader variant will forcibly align strides, for performance
|
||||
(q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
|
||||
|
|
@ -8691,7 +8689,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
// Use a placeholder core count if one isn't available. split_k is a big help for perf.
|
||||
const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count * shader_core_count_multiplier : 16;
|
||||
|
||||
auto rows_cols = fa_rows_cols(path, HSK, HSV, !aligned, k->type, rows, small_cache);
|
||||
auto rows_cols = fa_rows_cols(ctx->device, path, HSK, HSV, !aligned, k->type, rows, small_cache);
|
||||
const uint32_t Br = rows_cols[0];
|
||||
const uint32_t Bc = rows_cols[1];
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue