fix regressions
This commit is contained in:
parent
9f9a8743c4
commit
dd92b1f8d5
|
|
@ -2806,9 +2806,7 @@ static uint32_t fa_workgroup_size(const vk_device& device, FaCodePath path, uint
|
|||
case FA_COOPMAT1:
|
||||
return (Bc / 16) * subgroup_size; // enough subgroups for Bc/MatBc
|
||||
default:
|
||||
if (device->vendor_id == VK_VENDOR_ID_INTEL) {
|
||||
return 128;
|
||||
} else if (subgroup_size > 32 && Br < 4) {
|
||||
if (subgroup_size > 32 && Br < 4) {
|
||||
return subgroup_size * 2;
|
||||
} else {
|
||||
return subgroup_size * 4;
|
||||
|
|
@ -3222,8 +3220,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
return {fa_rows_cols(path, hsk, hsv, clamp, type, rows, small_cache)[0], 1, 1};
|
||||
};
|
||||
|
||||
const bool disable_subgroups = device->vendor_id == VK_VENDOR_ID_INTEL;
|
||||
|
||||
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache, uint32_t flags) -> std::vector<uint32_t> {
|
||||
// For large number of rows, 128 invocations seems to work best.
|
||||
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
|
||||
|
|
@ -3259,7 +3255,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
bool aligned = fa.first.aligned; \
|
||||
bool f32acc = fa.first.f32acc; \
|
||||
uint32_t flags = fa.first.flags; \
|
||||
bool fa_ds = path == FA_SCALAR && disable_subgroups; \
|
||||
bool fa_ds = fa_disable_subgroups(device, path); \
|
||||
uint32_t fa_sgs = fa_subgroup_size(device, path); \
|
||||
if (path == FAPATH) { \
|
||||
if (aligned) { \
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@
|
|||
const uint32_t HSK_per_thread = HSK / D_split;
|
||||
const uint32_t HSV_per_thread = HSV / D_split;
|
||||
|
||||
const uint32_t row_split = (Br < 4) ? 1 : 4;
|
||||
const uint32_t row_split = (Br < 4 || HSK <= 64) ? 1 : 4;
|
||||
const uint32_t rows_per_thread = Br / row_split;
|
||||
const uint32_t cols_per_iter = WorkGroupSize / D_split / row_split;
|
||||
const uint32_t cols_per_thread = Bc / cols_per_iter;
|
||||
|
|
@ -205,32 +205,61 @@ void main() {
|
|||
barrier();
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
|
||||
FLOAT_TYPEV4 Q_cache[rows_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Q_cache[r] = Qf[tile_row(r) * qf_stride + d * D_split + d_tid];
|
||||
}
|
||||
// More d iterations means Q register caching becomes relevant
|
||||
// Few iterations means the additional registers needed are worse than the speed-up from caching
|
||||
if (HSK_per_thread / 4 > 4) {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
|
||||
FLOAT_TYPEV4 Q_cache[rows_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Q_cache[r] = Qf[tile_row(r) * qf_stride + d * D_split + d_tid];
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
||||
continue;
|
||||
}
|
||||
|
||||
FLOAT_TYPEV4 K_Tf;
|
||||
if (SHMEM_STAGING != 0) {
|
||||
K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
|
||||
} else {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
#else
|
||||
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
|
||||
#endif
|
||||
}
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
||||
continue;
|
||||
}
|
||||
|
||||
FLOAT_TYPEV4 K_Tf;
|
||||
if (SHMEM_STAGING != 0) {
|
||||
K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
|
||||
} else {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
|
||||
FLOAT_TYPEV4 K_Tf;
|
||||
if (SHMEM_STAGING != 0) {
|
||||
K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
|
||||
} else {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
#else
|
||||
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
|
||||
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
|
||||
#endif
|
||||
}
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf));
|
||||
}
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue