commit a01c57deef
Author: bssrdf
Date:   2025-12-16 22:15:14 -05:00
Committed by: GitHub (GPG Key ID: B5690EEEBB952194; no known key found for this signature in database)

2 changed files with 64 additions and 25 deletions

Changed file 1 of 2 (CUDA cpy kernels):

@@ -39,7 +39,7 @@ static __global__ void cpy_scalar(const char * cx, char * cdst, const int ne,
     cpy_1(cx + x_offset, cdst + dst_offset);
 }
 
-template <typename T>
+template <typename T, int swap>
 static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
@@ -51,10 +51,10 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const
     const int64_t nmat = ne / (ne00 * ne01);
     const int64_t n = ne00 * ne01;
 
-    const int x = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x;
-    const int y = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
-    const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset
-    const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
+    const int x = (swap == 0 ? blockIdx.x : blockIdx.y) * CUDA_CPY_TILE_DIM_2D + threadIdx.x;
+    const int y = (swap == 0 ? blockIdx.y : blockIdx.x) * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
+    const int tx = (swap == 0 ? blockIdx.y : blockIdx.x) * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset
+    const int ty = (swap == 0 ? blockIdx.x : blockIdx.y) * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
 
     __shared__ float tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1];
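The new compile-time `swap` parameter only changes which grid axis walks which tensor extent; the shared-memory tile logic is untouched. A minimal standalone sketch of the same technique (not the ggml kernel itself; `TILE` and `transpose_tiled` are local names for this sketch, and the batching over `blockIdx.z` and the `CUDA_CPY_BLOCK_ROWS` row loop of the real kernel are omitted):

```cuda
// Shared-memory tiled transpose where a compile-time `swap` flag decides
// whether blockIdx.x or blockIdx.y walks the source columns.
constexpr int TILE = 32;

template <int swap>
__global__ void transpose_tiled(const float * src, float * dst, int rows, int cols) {
    __shared__ float tile[TILE][TILE + 1]; // +1 column of padding avoids bank conflicts

    const int bx = (swap == 0) ? blockIdx.x : blockIdx.y; // tiles along src columns
    const int by = (swap == 0) ? blockIdx.y : blockIdx.x; // tiles along src rows

    int x = bx * TILE + threadIdx.x; // source column
    int y = by * TILE + threadIdx.y; // source row
    if (x < cols && y < rows) {
        tile[threadIdx.y][threadIdx.x] = src[y * cols + x]; // coalesced read
    }
    __syncthreads();

    x = by * TILE + threadIdx.x; // destination column (= source row)
    y = bx * TILE + threadIdx.y; // destination row    (= source column)
    if (x < rows && y < cols) {
        dst[y * rows + x] = tile[threadIdx.x][threadIdx.y]; // coalesced write
    }
}
```

Resolving the ternaries at compile time costs nothing at runtime, and letting the host put the larger extent on `blockIdx.x` matters because CUDA allows up to 2^31-1 blocks in grid dimension x but only 65535 in y and z.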
@@ -200,23 +200,21 @@ static void ggml_cpy_scalar_cuda(
     if (transposed) {
         GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
-        int ne00n, ne01n, ne02n;
-        if (nb00 <= nb02) { // most likely safe to handle nb00 = nb02 case here
-            ne00n = ne00;
-            ne01n = ne01;
-            ne02n = ne02;
-        } else {
-            ne00n = ne00;
-            ne01n = ne01*ne02;
-            ne02n = 1;
-        }
-        dim3 dimGrid( (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
-                      (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
-                      (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
-        dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
-        cpy_scalar_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
-            (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        if (ne01 > ne00) {
+            dim3 dimGrid( (ne01 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
+                          (ne00 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
+                          (ne/(ne01*ne00) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
+            dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
+            cpy_scalar_transpose<dst_t, 0><<<dimGrid, dimBlock, 0, stream>>>
+                (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        } else {
+            dim3 dimGrid( (ne00 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
+                          (ne01 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
+                          (ne/(ne01*ne00) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
+            dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
+            cpy_scalar_transpose<dst_t, 1><<<dimGrid, dimBlock, 0, stream>>>
+                (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        }
     } else {
         const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
         cpy_scalar<cpy_1_scalar<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
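The two launch branches size the grid by ceil-division and pick the `swap` instantiation so that the larger of the two tiled extents always lands on `gridDim.x`, presumably because of the 65535-block cap on `gridDim.y`. A hedged usage sketch, continuing the standalone snippet above and reusing its local `TILE` and `transpose_tiled` names:

```cuda
// Dispatch mirroring the commit's ne01 > ne00 choice: the larger extent
// is always tiled by gridDim.x.
void launch_transpose(const float * src, float * dst, int rows, int cols, cudaStream_t stream) {
    const dim3 block(TILE, TILE, 1);
    if (cols > rows) {
        // swap == 0: blockIdx.x tiles the (larger) column extent
        const dim3 grid((cols + TILE - 1) / TILE, (rows + TILE - 1) / TILE, 1);
        transpose_tiled<0><<<grid, block, 0, stream>>>(src, dst, rows, cols);
    } else {
        // swap == 1: blockIdx.x tiles the (larger or equal) row extent
        const dim3 grid((rows + TILE - 1) / TILE, (cols + TILE - 1) / TILE, 1);
        transpose_tiled<1><<<grid, block, 0, stream>>>(src, dst, rows, cols);
    }
}
```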
@@ -359,9 +357,10 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
     GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
+    int64_t ne00 = src0->ne[0];
+    int64_t ne01 = src0->ne[1];
+    int64_t ne02 = src0->ne[2];
 
     //GGML_ASSERT(src0->ne[3] == 1);
@@ -387,8 +386,39 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     char * src1_ddc = (char *) src1->data;
 
     const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
-    const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) &&
-        src0->ne[3] == 1 && nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0);
+
+    bool can_be_transposed = false;
+    if (src0->ne[3] == 1) {
+        int64_t ne00n, ne01n, ne02n;
+        if (nb01 == (int64_t)ggml_element_size(src0) &&
+            (nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0) ||
+             nb00 == ne01 * ne02 * (int64_t)ggml_element_size(src0))) {
+            if (nb00 <= nb02) { // most likely safe to handle nb00 = nb02 case here
+                ne00n = ne00;
+                ne01n = ne01;
+                ne02n = ne02;
+            } else {
+                ne00n = ne00;
+                ne01n = ne01*ne02;
+                ne02n = 1;
+            }
+            ne00 = ne00n;
+            ne01 = ne01n;
+            ne02 = ne02n;
+            can_be_transposed = true;
+        }
+        if (nb02 == (int64_t)ggml_element_size(src0) && nb00 <= nb01 &&
+            nb01 == ne02 * ne00 * (int64_t)ggml_element_size(src0)) {
+            // GGML_ASSERT(nb00 <= nb01);
+            ne00n = ne00*ne01;
+            ne01n = ne02;
+            ne02n = 1; // not used
+            ne00 = ne00n;
+            ne01 = ne01n;
+            ne02 = ne02n;
+            can_be_transposed = true;
+        }
+    }
 
     if (src0->type == src1->type && contiguous_srcs) {
         GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
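This is why `ne00`/`ne01`/`ne02` lose their `const` qualifier earlier in the function: the detection block collapses the permuted view into an equivalent 2D transpose in place. To make the stride conditions concrete, here is a self-contained check in plain C++ (numbers taken from the new BF16 eval test below: a contiguous {16, 256, 1, 1} tensor viewed through the {2, 0, 1, 3} permutation, i.e. ne = {256, 1, 16, 1}):

```cpp
#include <cassert>
#include <cstdint>

int main() {
    const int64_t elsize = 2; // sizeof(ggml_bf16_t)

    // dims and byte strides of the permuted view:
    const int64_t ne00 = 256, ne01 = 1, ne02 = 16;
    const int64_t nb00 = 16 * elsize;       // 32   (was dim 1 of the parent)
    const int64_t nb01 = 16 * 256 * elsize; // 8192 (was dim 2 of the parent)
    const int64_t nb02 = elsize;            // 2    (was dim 0: fastest in memory)

    // second detection branch: dim 2 moves fastest in memory, so the copy
    // is a 2D transpose of an ne02 x (ne00*ne01) matrix
    assert(nb02 == elsize && nb00 <= nb01 && nb01 == ne02 * ne00 * elsize);

    // the collapsed shape handed to the transpose kernel:
    const int64_t ne00n = ne00 * ne01; // 256
    const int64_t ne01n = ne02;        // 16
    assert(ne00n == 256 && ne01n == 16);
    return 0;
}
```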

Changed file 2 of 2 (backend op tests):

@@ -7173,6 +7173,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_I32, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
     test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_I32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0}));
     test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0}));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {4, 1, 256, 1}, {1, 2, 0, 3}, {0, 0, 0, 0}));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {16, 256, 1, 1}, {2, 0, 1, 3}, {0, 0, 0, 0}));
 
     for (ggml_type type_dst : { GGML_TYPE_F32, GGML_TYPE_I32, GGML_TYPE_F16, GGML_TYPE_BF16 }) {
         for (bool use_view_slice : { true, false }) {
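For orientation: the fourth constructor argument of test_cpy is the permutation applied to the source view before the copy. A rough plain-ggml equivalent of the new F32 case (assuming an initialized `ggml_context ctx`; the real test harness does more bookkeeping):

```cpp
// test_cpy(F32, F32, {4, 1, 256, 1}, {1, 2, 0, 3}, ...): copy a permuted,
// non-contiguous view into a contiguous tensor of the permuted shape.
struct ggml_tensor * src  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 4, 1, 256, 1);
struct ggml_tensor * view = ggml_permute(ctx, src, 1, 2, 0, 3); // ne = {256, 4, 1, 1}, non-contiguous
struct ggml_tensor * dst  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 256, 4, 1, 1);
struct ggml_tensor * out  = ggml_cpy(ctx, view, dst); // should now take the CUDA transpose path
```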
@@ -8063,6 +8066,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
     test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
+    // sd.cpp cases
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {4352, 1, 9216, 1}, {1, 2, 0, 3}, {0, 0, 0, 0}));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {4352, 1, 9216, 1}, {1, 2, 0, 3}, {0, 0, 0, 0}));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {21504, 4352, 1, 1}, {2, 0, 1, 3}, {0, 0, 0, 0}));
 
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
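These perf shapes come from sd.cpp (stable-diffusion.cpp) workloads. The two {4352, 1, 9216, 1} cases hit the first detection branch, including its nb00 == nb02 corner case; a hand-checked sketch of the arithmetic for the F32 case, with the strides of the permuted view computed by hand as in the earlier example:

```cpp
#include <cassert>
#include <cstdint>

int main() {
    // {4352, 1, 9216, 1} permuted {1, 2, 0, 3} gives ne = {9216, 4352, 1, 1}
    // with byte strides nb00 = 17408, nb01 = 4, nb02 = 17408.
    const int64_t elsize = 4; // sizeof(float)
    const int64_t ne00 = 9216, ne01 = 4352, ne02 = 1;
    const int64_t nb00 = 17408, nb01 = 4, nb02 = 17408;

    // first detection branch: dim 1 moves fastest; here the
    // nb00 == ne01*ne02*elsize disjunct holds, and nb00 == nb02
    // (the corner case the code comment flags)
    assert(nb01 == elsize &&
           (nb02 == ne00 * ne01 * elsize || nb00 == ne01 * ne02 * elsize));
    assert(nb00 <= nb02); // so the shape is kept as-is: ne00 = 9216, ne01 = 4352

    // since ne01 < ne00, the dispatch then takes the swap == 1 launch branch
    return 0;
}
```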