fix regressions

This commit is contained in:
Ruben Ortlam 2026-02-14 06:45:58 +01:00
parent 9f9a8743c4
commit dd92b1f8d5
2 changed files with 49 additions and 24 deletions

View File

@ -2806,9 +2806,7 @@ static uint32_t fa_workgroup_size(const vk_device& device, FaCodePath path, uint
case FA_COOPMAT1:
return (Bc / 16) * subgroup_size; // enough subgroups for Bc/MatBc
default:
if (device->vendor_id == VK_VENDOR_ID_INTEL) {
return 128;
} else if (subgroup_size > 32 && Br < 4) {
if (subgroup_size > 32 && Br < 4) {
return subgroup_size * 2;
} else {
return subgroup_size * 4;
@ -3222,8 +3220,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
return {fa_rows_cols(path, hsk, hsv, clamp, type, rows, small_cache)[0], 1, 1};
};
const bool disable_subgroups = device->vendor_id == VK_VENDOR_ID_INTEL;
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache, uint32_t flags) -> std::vector<uint32_t> {
// For large number of rows, 128 invocations seems to work best.
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
@ -3259,7 +3255,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
bool aligned = fa.first.aligned; \
bool f32acc = fa.first.f32acc; \
uint32_t flags = fa.first.flags; \
bool fa_ds = path == FA_SCALAR && disable_subgroups; \
bool fa_ds = fa_disable_subgroups(device, path); \
uint32_t fa_sgs = fa_subgroup_size(device, path); \
if (path == FAPATH) { \
if (aligned) { \

View File

@ -19,7 +19,7 @@
const uint32_t HSK_per_thread = HSK / D_split;
const uint32_t HSV_per_thread = HSV / D_split;
const uint32_t row_split = (Br < 4) ? 1 : 4;
const uint32_t row_split = (Br < 4 || HSK <= 64) ? 1 : 4;
const uint32_t rows_per_thread = Br / row_split;
const uint32_t cols_per_iter = WorkGroupSize / D_split / row_split;
const uint32_t cols_per_thread = Bc / cols_per_iter;
@ -205,32 +205,61 @@ void main() {
barrier();
}
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
FLOAT_TYPEV4 Q_cache[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Q_cache[r] = Qf[tile_row(r) * qf_stride + d * D_split + d_tid];
}
// More d iterations means Q register caching becomes relevant
// Few iterations means the additional registers needed are worse than the speed-up from caching
if (HSK_per_thread / 4 > 4) {
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
FLOAT_TYPEV4 Q_cache[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Q_cache[r] = Qf[tile_row(r) * qf_stride + d * D_split + d_tid];
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
FLOAT_TYPEV4 K_Tf;
if (SHMEM_STAGING != 0) {
K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
} else {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
#endif
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf));
}
}
}
} else {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
FLOAT_TYPEV4 K_Tf;
if (SHMEM_STAGING != 0) {
K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
} else {
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
FLOAT_TYPEV4 K_Tf;
if (SHMEM_STAGING != 0) {
K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
} else {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
#endif
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf));
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf));
}
}
}
}