make split-k condition check more robust

This commit is contained in:
bssrdf 2025-11-08 18:47:12 -05:00
parent a1fb3c1509
commit a3fb36fb71
2 changed files with 3 additions and 3 deletions

View File

@ -859,7 +859,7 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
// if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) {
if (BlocksM * BlocksN < nsm){
int ks = nsm / (BlocksM * BlocksN);
int ks = min(12, nsm / (BlocksM * BlocksN));
int j;
bool can_split = false;
for (j = ks; j >= 2; j--){
@ -909,7 +909,7 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
const unsigned int ksplit = 11;
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else {
} else if(j == 12) {
const unsigned int ksplit = 12;
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);

View File

@ -653,7 +653,7 @@ int main(void)
int k = 0;
for (auto c : configs_sdxl_768){
for (auto c : configs_sdxl_1024){
// for (auto c : configs){
test_model model;
load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c),