make split-k condition check more robust
This commit is contained in:
parent
a1fb3c1509
commit
a3fb36fb71
|
|
@@ -859,7 +859,7 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
|
|||
// if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) {
|
||||
if (BlocksM * BlocksN < nsm){
|
||||
|
||||
- int ks = nsm / (BlocksM * BlocksN);
|
||||
+ int ks = min(12, nsm / (BlocksM * BlocksN));
|
||||
int j;
|
||||
bool can_split = false;
|
||||
for (j = ks; j >= 2; j--){
|
||||
|
|
@@ -909,7 +909,7 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
|
|||
const unsigned int ksplit = 11;
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
|
||||
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
|
||||
- } else {
|
||||
+ } else if(j == 12) {
|
||||
const unsigned int ksplit = 12;
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
|
||||
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
|
||||
|
|
|
|||
|
|
@@ -653,7 +653,7 @@ int main(void)
|
|||
|
||||
int k = 0;
|
||||
|
||||
- for (auto c : configs_sdxl_768){
|
||||
+ for (auto c : configs_sdxl_1024){
|
||||
// for (auto c : configs){
|
||||
test_model model;
|
||||
load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c),
|
||||
|
|
|
|||
Loading…
Reference in New Issue