diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index c2cc1930cb..afca57459a 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -306,26 +306,27 @@ int main(void) { ggml_time_init(); std::vector> configs = { - // std::make_tuple(64,64,48,64,3,3), - // std::make_tuple(320,320,104,152,3,3), - // std::make_tuple(640,640,52,76,3,3), - // std::make_tuple(640,640,104,152,3,3), - // std::make_tuple(960,320,104,152,3,3), - // std::make_tuple(1280,1280,26,38,3,3), - // std::make_tuple(1280,1280,26,38,1,1), - // std::make_tuple(256,128,768,1024,3,3), - // std::make_tuple(128,3,768,1024,3,3), - // std::make_tuple(256,128,768,1024,1,1), - // std::make_tuple(512,256,384,512,1,1), - // std::make_tuple(1280,640,52,76,3,3), - // std::make_tuple(1920,1280,26,38,3,3), - // std::make_tuple(2560,1280,26,38,3,3), + std::make_tuple(64,64,48,64,3,3), + std::make_tuple(320,320,104,152,3,3), + std::make_tuple(640,640,52,76,3,3), + std::make_tuple(640,640,104,152,3,3), + std::make_tuple(960,320,104,152,3,3), + std::make_tuple(1280,1280,26,38,3,3), std::make_tuple(320,1280,26,38,3,3), - // std::make_tuple(512,512,104,152,3,3), - // std::make_tuple(512,512,208,304,3,3), - // std::make_tuple(512,256,416,608,3,3), - // std::make_tuple(256,128,832,1216,3,3), - // std::make_tuple(256,256,832,1216,3,3), + std::make_tuple(1280,1280,26,38,1,1), + std::make_tuple(256,128,768,1024,3,3), + std::make_tuple(128,3,768,1024,3,3), + std::make_tuple(256,128,768,1024,1,1), + std::make_tuple(512,256,384,512,1,1), + std::make_tuple(1280,640,52,76,3,3), + std::make_tuple(1920,1280,26,38,3,3), + std::make_tuple(2560,1280,26,38,3,3), + std::make_tuple(320,1280,26,38,3,3), + std::make_tuple(512,512,104,152,3,3), + std::make_tuple(512,512,208,304,3,3), + std::make_tuple(512,256,416,608,3,3), + std::make_tuple(256,128,832,1216,3,3), + std::make_tuple(256,256,832,1216,3,3), // std::make_tuple(320,256,1024,1920) }; @@ -377,7 +378,7 @@ int main(void) if(k==0) { k = 1; - fprintf(stderr, "| (IC, OC, IW, IH) | im2col+GEMM TIME | im2col+GEMM VRAM | implicit GEMM TIME | implicit GEMM VRAM \n"); + fprintf(stderr, "| (IC, OC, IW, IH, KW, KH) | im2col+GEMM TIME | im2col+GEMM VRAM | implicit GEMM TIME | implicit GEMM VRAM \n"); fprintf(stderr, "| --- | --- | --- | --- | --- \n"); }