use a better criterion to decide when to use split-k
This commit is contained in:
parent
688de6d7d8
commit
d9a48580fc
|
|
@@ -984,9 +984,9 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
|
|||
constexpr unsigned int NumThreads = ThreadsM * ThreadsN;
|
||||
const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half);
|
||||
|
||||
const unsigned int K2MN = 8;
|
||||
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
|
||||
|
||||
if (P.c * P.r * P.s > K2MN * P.n * P.Oh * P.Ow || P.c * P.r * P.s > K2MN * P.k) {
|
||||
if (BlocksM * BlocksN < nsm) {
|
||||
const unsigned int ksplit = 8;
|
||||
ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), ksplit * P.k * P.Oh * P.Ow * P.n);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue