diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index c4ceb4fc57..1222b11fdf 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -200,23 +200,13 @@ static void ggml_cpy_scalar_cuda( if (transposed) { GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed - int ne00n, ne01n, ne02n; - if (nb00 <= nb02) { // most likely safe to handle nb00 = nb02 case here - ne00n = ne00; - ne01n = ne01; - ne02n = ne02; - } else { - ne00n = ne00; - ne01n = ne01*ne02; - ne02n = 1; - } - dim3 dimGrid( (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D, - (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D, - (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM); + dim3 dimGrid( (ne01 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D, + (ne00 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D, + (ne/(ne01*ne00) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM); dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1); cpy_scalar_transpose<<>> - (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } else { const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; cpy_scalar><<>> @@ -359,9 +349,10 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX); GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX); - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; + int64_t ne00 = src0->ne[0]; + int64_t ne01 = src0->ne[1]; + int64_t ne02 = src0->ne[2]; + //GGML_ASSERT(src0->ne[3] == 1); @@ -387,8 +378,39 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg char * src1_ddc = (char *) src1->data; const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1); - const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) && - src0->ne[3] == 1 && nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0); + + bool can_be_transposed = false; + if (src0->ne[3] == 1 ) { + int64_t ne00n, ne01n, ne02n; + if (nb01 == (int64_t)ggml_element_size(src0) && + (nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0) || + nb00 == ne01 * ne02 * (int64_t)ggml_element_size(src0))) { + if (nb00 <= nb02) { // most likely safe to handle nb00 = nb02 case here + ne00n = ne00; + ne01n = ne01; + ne02n = ne02; + } else { + ne00n = ne00; + ne01n = ne01*ne02; + ne02n = 1; + } + ne00 = ne00n ; + ne01 = ne01n; + ne02 = ne02n; + can_be_transposed = true; + } + if ((nb02 == (int64_t)ggml_element_size(src0) && + nb01 == ne02 * ne00 * (int64_t)ggml_element_size(src0))) { + GGML_ASSERT(nb00 <= nb01); + ne00n = ne00*ne01; + ne01n = ne02; + ne02n = 1; // not used + ne00 = ne00n ; + ne01 = ne01n; + ne02 = ne02n; + can_be_transposed = true; + } + } if (src0->type == src1->type && contiguous_srcs) { GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1)); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 7ef7f2ad81..45a6dbc498 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7928,6 +7928,11 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {4352, 1, 9216, 1}, {1, 2, 0, 3}, {0, 0, 0, 0})); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {4352, 1, 9216, 1}, {1, 2, 0, 3}, {0, 0, 0, 0})); + test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {21504, 4352, 1, 1}, {2, 0, 1, 3}, {0, 0, 0, 0})); + + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));