fix rebase issues

This commit is contained in:
Ruben Ortlam 2026-02-12 11:39:28 +01:00
parent 28a3c0b859
commit 3946eb657f
1 changed files with 5 additions and 15 deletions

View File

@ -2766,10 +2766,10 @@ static uint32_t get_fa_scalar_num_rows(uint32_t hsk, uint32_t hsv, FaRows rows,
if (rows == FA_ROWS_1) {
return 1;
} else if (rows == FA_ROWS_SMALL) {
return 8;
return 4;
}
if (hsv >= 192 || (hsv | hsk) & 8 || small_cache || rows == FA_ROWS_2 || rows == FA_ROWS_4 || rows == FA_ROWS_8) {
if (hsv >= 192 || (hsv | hsk) & 8 || small_cache) {
return 8;
}
@ -2792,14 +2792,6 @@ static uint32_t fa_subgroup_size(const vk_device& device, FaCodePath path) {
return 0xFFFFFFFF;
}
if (path == FA_VECTOR) {
if (device->vendor_id == VK_VENDOR_ID_AMD && device->subgroup_min_size <= 32 && device->subgroup_max_size >= 32) {
return 32;
} else if (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) {
return device->subgroup_min_size;
}
}
return device->subgroup_size;
}
@ -2810,8 +2802,6 @@ static uint32_t fa_workgroup_size(const vk_device& device, FaCodePath path, uint
return ((rows != FA_ROWS_LARGE && (D % 32) == 0) ? 256 : 128);
case FA_COOPMAT1:
return (Bc / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc
case FA_VECTOR:
return device->vendor_id == VK_VENDOR_ID_AMD ? 256 : 128;
default:
if (device->vendor_id == VK_VENDOR_ID_INTEL) {
return 128;
@ -8426,7 +8416,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache, bool fp32acc) {
// Needs to be kept up to date on shader changes
const std::array<uint32_t, 2> rows_cols = fa_rows_cols(device, FA_SCALAR, hsk, hsv, clamp, type, rows, small_cache);
const std::array<uint32_t, 2> rows_cols = fa_rows_cols(FA_SCALAR, hsk, hsv, clamp, type, rows, small_cache);
const uint32_t Br = rows_cols[0];
const uint32_t Bc = rows_cols[1];
const uint32_t wg_size = fa_workgroup_size(device, FA_SCALAR, hsk, hsv, rows, Br, Bc);
@ -8625,7 +8615,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
// with large hsk/hsv, scalar path may need to use small rows to fit in shared memory
if (path == FA_SCALAR && rows == FA_ROWS_LARGE && !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, 0, k->type, FA_ROWS_LARGE, small_cache, f32acc)) {
rows = FA_ROWS_8;
rows = FA_ROWS_SMALL;
}
const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
@ -8694,7 +8684,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
// Use a placeholder core count if one isn't available. split_k is a big help for perf.
const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
auto rows_cols = fa_rows_cols(ctx->device, path, HSK, HSV, !aligned, k->type, rows, small_cache);
auto rows_cols = fa_rows_cols(path, HSK, HSV, !aligned, k->type, rows, small_cache);
const uint32_t Br = rows_cols[0];
const uint32_t Bc = rows_cols[1];