Make sure there are enough channels for split-k: only take the split-k path when P.c > 8 * ksplit
This commit is contained in:
parent
09e3a5f07d
commit
311213d209
|
|
@ -991,9 +991,8 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
|
||||||
const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half);
|
const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half);
|
||||||
|
|
||||||
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
|
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
|
||||||
|
const unsigned int ksplit = 8;
|
||||||
if (BlocksM * BlocksN < nsm) {
|
if (BlocksM * BlocksN < nsm && P.c > 8 * ksplit) {
|
||||||
const unsigned int ksplit = 8;
|
|
||||||
ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), ksplit * P.k * P.Oh * P.Ow * P.n);
|
ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), ksplit * P.k * P.Oh * P.Ow * P.n);
|
||||||
|
|
||||||
cudaFuncSetAttribute(conv2d_implicit_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit, NumThreads>,
|
cudaFuncSetAttribute(conv2d_implicit_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit, NumThreads>,
|
||||||
|
|
|
||||||
|
|
@ -324,6 +324,7 @@ int main(void)
|
||||||
std::make_tuple(256,128,832,1216,3,3),
|
std::make_tuple(256,128,832,1216,3,3),
|
||||||
std::make_tuple(256,256,832,1216,3,3),
|
std::make_tuple(256,256,832,1216,3,3),
|
||||||
// std::make_tuple(320,256,1024,1920)
|
// std::make_tuple(320,256,1024,1920)
|
||||||
|
// std::make_tuple(32,64,58,58,3,3)
|
||||||
};
|
};
|
||||||
|
|
||||||
int k = 0;
|
int k = 0;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue