fix rebase issues
This commit is contained in:
parent
28a3c0b859
commit
3946eb657f
|
|
@ -2766,10 +2766,10 @@ static uint32_t get_fa_scalar_num_rows(uint32_t hsk, uint32_t hsv, FaRows rows,
|
|||
if (rows == FA_ROWS_1) {
|
||||
return 1;
|
||||
} else if (rows == FA_ROWS_SMALL) {
|
||||
return 8;
|
||||
return 4;
|
||||
}
|
||||
|
||||
if (hsv >= 192 || (hsv | hsk) & 8 || small_cache || rows == FA_ROWS_2 || rows == FA_ROWS_4 || rows == FA_ROWS_8) {
|
||||
if (hsv >= 192 || (hsv | hsk) & 8 || small_cache) {
|
||||
return 8;
|
||||
}
|
||||
|
||||
|
|
@ -2792,14 +2792,6 @@ static uint32_t fa_subgroup_size(const vk_device& device, FaCodePath path) {
|
|||
return 0xFFFFFFFF;
|
||||
}
|
||||
|
||||
if (path == FA_VECTOR) {
|
||||
if (device->vendor_id == VK_VENDOR_ID_AMD && device->subgroup_min_size <= 32 && device->subgroup_max_size >= 32) {
|
||||
return 32;
|
||||
} else if (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) {
|
||||
return device->subgroup_min_size;
|
||||
}
|
||||
}
|
||||
|
||||
return device->subgroup_size;
|
||||
}
|
||||
|
||||
|
|
@ -2810,8 +2802,6 @@ static uint32_t fa_workgroup_size(const vk_device& device, FaCodePath path, uint
|
|||
return ((rows != FA_ROWS_LARGE && (D % 32) == 0) ? 256 : 128);
|
||||
case FA_COOPMAT1:
|
||||
return (Bc / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc
|
||||
case FA_VECTOR:
|
||||
return device->vendor_id == VK_VENDOR_ID_AMD ? 256 : 128;
|
||||
default:
|
||||
if (device->vendor_id == VK_VENDOR_ID_INTEL) {
|
||||
return 128;
|
||||
|
|
@ -8426,7 +8416,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache, bool fp32acc) {
|
||||
// Needs to be kept up to date on shader changes
|
||||
const std::array<uint32_t, 2> rows_cols = fa_rows_cols(device, FA_SCALAR, hsk, hsv, clamp, type, rows, small_cache);
|
||||
const std::array<uint32_t, 2> rows_cols = fa_rows_cols(FA_SCALAR, hsk, hsv, clamp, type, rows, small_cache);
|
||||
const uint32_t Br = rows_cols[0];
|
||||
const uint32_t Bc = rows_cols[1];
|
||||
const uint32_t wg_size = fa_workgroup_size(device, FA_SCALAR, hsk, hsv, rows, Br, Bc);
|
||||
|
|
@ -8625,7 +8615,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
|
||||
// with large hsk/hsv, scalar path may need to use small rows to fit in shared memory
|
||||
if (path == FA_SCALAR && rows == FA_ROWS_LARGE && !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, 0, k->type, FA_ROWS_LARGE, small_cache, f32acc)) {
|
||||
rows = FA_ROWS_8;
|
||||
rows = FA_ROWS_SMALL;
|
||||
}
|
||||
|
||||
const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
|
||||
|
|
@ -8694,7 +8684,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
// Use a placeholder core count if one isn't available. split_k is a big help for perf.
|
||||
const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
|
||||
|
||||
auto rows_cols = fa_rows_cols(ctx->device, path, HSK, HSV, !aligned, k->type, rows, small_cache);
|
||||
auto rows_cols = fa_rows_cols(path, HSK, HSV, !aligned, k->type, rows, small_cache);
|
||||
const uint32_t Br = rows_cols[0];
|
||||
const uint32_t Bc = rows_cols[1];
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue