relax the transposed copy condition check a bit and add another case

This commit is contained in:
bssrdf 2025-12-03 21:39:05 -05:00
parent dea9ba27cb
commit adb72c52f1
2 changed files with 46 additions and 19 deletions

View File

@ -200,23 +200,13 @@ static void ggml_cpy_scalar_cuda(
if (transposed) {
GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
int ne00n, ne01n, ne02n;
if (nb00 <= nb02) { // most likely safe to handle nb00 = nb02 case here
ne00n = ne00;
ne01n = ne01;
ne02n = ne02;
} else {
ne00n = ne00;
ne01n = ne01*ne02;
ne02n = 1;
}
dim3 dimGrid( (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
(ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
(ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
dim3 dimGrid( (ne01 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
(ne00 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
(ne/(ne01*ne00) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
cpy_scalar_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
(cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} else {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_scalar<cpy_1_scalar<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
@ -359,9 +349,10 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
int64_t ne00 = src0->ne[0];
int64_t ne01 = src0->ne[1];
int64_t ne02 = src0->ne[2];
//GGML_ASSERT(src0->ne[3] == 1);
@ -387,8 +378,39 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
char * src1_ddc = (char *) src1->data;
const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) &&
src0->ne[3] == 1 && nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0);
bool can_be_transposed = false;
if (src0->ne[3] == 1 ) {
int64_t ne00n, ne01n, ne02n;
if (nb01 == (int64_t)ggml_element_size(src0) &&
(nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0) ||
nb00 == ne01 * ne02 * (int64_t)ggml_element_size(src0))) {
if (nb00 <= nb02) { // most likely safe to handle nb00 = nb02 case here
ne00n = ne00;
ne01n = ne01;
ne02n = ne02;
} else {
ne00n = ne00;
ne01n = ne01*ne02;
ne02n = 1;
}
ne00 = ne00n ;
ne01 = ne01n;
ne02 = ne02n;
can_be_transposed = true;
}
if ((nb02 == (int64_t)ggml_element_size(src0) &&
nb01 == ne02 * ne00 * (int64_t)ggml_element_size(src0))) {
GGML_ASSERT(nb00 <= nb01);
ne00n = ne00*ne01;
ne01n = ne02;
ne02n = 1; // not used
ne00 = ne00n ;
ne01 = ne01n;
ne02 = ne02n;
can_be_transposed = true;
}
}
if (src0->type == src1->type && contiguous_srcs) {
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));

View File

@ -7928,6 +7928,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {4352, 1, 9216, 1}, {1, 2, 0, 3}, {0, 0, 0, 0}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {4352, 1, 9216, 1}, {1, 2, 0, 3}, {0, 0, 0, 0}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {21504, 4352, 1, 1}, {2, 0, 1, 3}, {0, 0, 0, 0}));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));