WIP: debugging cpy transpose

bssrdf 2025-10-27 15:09:03 -04:00
parent cc327f5224
commit a3784e17ad
4 changed files with 81 additions and 40 deletions


@@ -39,7 +39,7 @@ static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne
template <typename T>
static __global__ void cpy_flt_transpose(char * cx, char * cdst_direct, const int ne,
static __global__ void cpy_flt_transpose(const char * cx, char * cdst_direct, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
@@ -58,22 +58,31 @@ static __global__ void cpy_flt_transpose(char * cx, char * cdst_direct, const i
int tx = blockIdx.y * TILE_DIM + threadIdx.x; // transpose block offset
int ty = blockIdx.x * TILE_DIM + threadIdx.y;
__shared__ T tile[TILE_DIM * TILE_DIM];
// __shared__ T tile[TILE_DIM * TILE_DIM];
__shared__ T tile[TILE_DIM][TILE_DIM];
for(int i = 0; i < BLOCK_NM; ++i){
const unsigned int imat = blockIdx.z * BLOCK_NM + i;
if(imat < nmat){
for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS){
const unsigned int idx = (y+j)*width + x;
if(idx < n)
tile[threadIdx.y+j][threadIdx.x] = src[imat*n + idx];
if(idx < n){
const int row = threadIdx.y+j;
const int col = threadIdx.x ^ row;
// tile[threadIdx.y+j][threadIdx.x] = src[imat*n + idx];
tile[row][col] = src[imat*n + idx];
}
}
__syncthreads();
for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS){
const unsigned int idx = (ty+j)*width + tx;
if(idx < n)
dst[imat*n + idx] = tile[threadIdx.x][threadIdx.y + j];
if(idx < n){
// const int row = threadIdx.x;
const int col = (threadIdx.y+j) ^ threadIdx.x;
// dst[imat*n + idx] = tile[threadIdx.x][threadIdx.y + j];
dst[imat*n + idx] = tile[threadIdx.x][col];
}
}
}
}
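
For reference, the swizzled indexing introduced above is the usual shared-memory trick for a tiled transpose: element (row, col) of the tile is stored at physical column col ^ row, so both the row-wise store and the column-wise read after the transpose hit distinct banks. A minimal standalone sketch of the same pattern, with hypothetical names (TILE/ROWS play the role of TILE_DIM/BLOCK_ROWS; the 32/8 values are illustrative, and this is not the ggml kernel itself):

// transpose_swizzle_demo.cu -- illustrative only, single float matrix
#include <cuda_runtime.h>
#include <cstdio>

constexpr int TILE = 32;
constexpr int ROWS = 8;

__global__ void transpose_swizzled(const float * src, float * dst, int width, int height) {
    __shared__ float tile[TILE][TILE];

    int x = blockIdx.x * TILE + threadIdx.x;   // column in src
    int y = blockIdx.y * TILE + threadIdx.y;   // row in src

    // Load a TILE x TILE block, swizzling the column so the transposed
    // read below does not cause shared-memory bank conflicts.
    for (int j = 0; j < TILE; j += ROWS) {
        if (x < width && y + j < height) {
            int row = threadIdx.y + j;
            tile[row][threadIdx.x ^ row] = src[(y + j) * width + x];
        }
    }
    __syncthreads();

    int tx = blockIdx.y * TILE + threadIdx.x;  // column in dst (= row in src)
    int ty = blockIdx.x * TILE + threadIdx.y;  // row in dst (= column in src)

    for (int j = 0; j < TILE; j += ROWS) {
        if (tx < height && ty + j < width) {
            // logical tile element (threadIdx.x, threadIdx.y + j) was stored at
            // physical column (threadIdx.y + j) ^ threadIdx.x
            dst[(ty + j) * height + tx] = tile[threadIdx.x][(threadIdx.y + j) ^ threadIdx.x];
        }
    }
}

int main() {
    const int W = 100, H = 60;
    float *src, *dst;
    cudaMallocManaged(&src, W * H * sizeof(float));
    cudaMallocManaged(&dst, W * H * sizeof(float));
    for (int i = 0; i < W * H; ++i) src[i] = (float) i;

    dim3 grid((W + TILE - 1) / TILE, (H + TILE - 1) / TILE, 1);
    dim3 block(TILE, ROWS, 1);
    transpose_swizzled<<<grid, block>>>(src, dst, W, H);
    cudaDeviceSynchronize();

    // spot-check: dst is H x W in row-major terms, dst[c*H + r] == src[r*W + c]
    printf("ok=%d\n", dst[3 * H + 5] == src[5 * W + 3]);
    cudaFree(src); cudaFree(dst);
    return 0;
}
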
@@ -180,30 +189,33 @@ void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_des
#endif
}
template<typename src_t, typename dst_t>
template<typename src_t, typename dst_t, bool transpose = false>
static void ggml_cpy_flt_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
if constexpr (std::is_same_v<src_t, half> && std::is_same_v<dst_t, half> ||
std::is_same_v<src_t, float> && std::is_same_v<dst_t, float>
){
if (ne00 == ne11 && ne01 == ne10 && nb00 == nb11 && nb10 == nb01){ //transpose
if constexpr ((std::is_same_v<src_t, half> && std::is_same_v<dst_t, half> ||
std::is_same_v<src_t, float> && std::is_same_v<dst_t, float>)
&& transpose){
// printf("cuda cpy transpose ne=%d ne00=%d ne01=%d ne10=%d ne11=%d\n", ne, ne00, ne01, ne10, ne11);
// printf("cuda cpy transpose nb00=%d nb01=%d nb10=%d nb11=%d\n", nb00, nb01, nb10, nb11);
// if (ne00 == ne11 && ne01 == ne10 && nb00 == nb11 && nb10 == nb01){ //transpose
// if (transpose) { //transpose
// printf("cuda cpy transpose ne=%d ne00=%d ne01=%d ne10=%d ne11=%d\n", ne, ne00, ne01, ne10, ne11);
dim3 dimGrid( (ne00 + TILE_DIM - 1) / TILE_DIM,
(ne01 + TILE_DIM - 1) / TILE_DIM,
(ne/(ne00*ne01) + BLOCK_NM - 1) / BLOCK_NM );
dim3 dimBlock(TILE_DIM, BLOCK_ROWS, 1);
cpy_flt_transpose<cpy_1_flt<dst_t>><<<dimGrid, dimBlock, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
} else{ // other
cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
}
} else{
cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
} else{ // other
cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
}
// } else{
// cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
// (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
// }
}
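
The dispatch change above moves the transpose decision from a runtime shape check to a compile-time template parameter: `transpose` defaults to false, so existing `ggml_cpy_flt_cuda<src_t, dst_t>` call sites keep compiling, and `if constexpr` discards the transpose branch entirely for non-matching or non-transposed instantiations. A tiny host-only sketch of the same pattern (hypothetical names, not the ggml code):

#include <cstdio>
#include <type_traits>

// Compile-time routing: the unused branch is dropped per instantiation,
// so no runtime shape comparison is needed at the call site.
template <typename src_t, typename dst_t, bool transpose = false>
static void copy_dispatch() {
    if constexpr ((std::is_same_v<src_t, float> && std::is_same_v<dst_t, float>) && transpose) {
        std::printf("transpose tile kernel\n");
    } else {
        std::printf("generic element-wise copy\n");
    }
}

int main() {
    copy_dispatch<float, float, true>();   // -> transpose tile kernel
    copy_dispatch<float, float, false>();  // -> generic element-wise copy
    copy_dispatch<float, double>();        // default transpose = false -> generic
    return 0;
}
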
static void ggml_cpy_f32_q8_0_cuda(
@@ -389,7 +401,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
if(src1->op_params[10] == 999){
ggml_cpy_flt_cuda<float, float, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else {
ggml_cpy_flt_cuda<float, float, false> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
@@ -420,7 +436,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
if(src1->op_params[10] == 999){
ggml_cpy_flt_cuda<half, half, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else {
ggml_cpy_flt_cuda<half, half, false> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
}
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
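
Both same-type branches above route on `src1->op_params[10] == 999`, the ad-hoc sentinel this WIP writes in `ggml_transpose` (see the ggml.c hunks below). A hypothetical helper, just to name the convention; the slot index 10 and the value 999 are this commit's choice, not a ggml API:

#include "ggml.h"

// true if the destination tensor was tagged by ggml_transpose in this WIP;
// one of the spare int32 op_params slots (index 10) is used as the flag.
static inline bool cpy_dst_is_transposed(const struct ggml_tensor * dst) {
    return dst->op_params[10] == 999;
}

// usage in the F32->F32 branch above:
//   if (cpy_dst_is_transposed(src1)) { ggml_cpy_flt_cuda<float, float, true >(...); }
//   else                             { ggml_cpy_flt_cuda<float, float, false>(...); }
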


@@ -3301,6 +3301,9 @@ static struct ggml_tensor * ggml_cont_impl(
result->op = GGML_OP_CONT;
result->src[0] = a;
if (a->op == GGML_OP_TRANSPOSE) {
result->op_params[10] = a->op_params[10]; // preserve the original order
}
return result;
}
@@ -3614,6 +3617,7 @@ struct ggml_tensor * ggml_transpose(
result->op = GGML_OP_TRANSPOSE;
result->src[0] = a;
result->op_params[10] = 999; // the transpose flag
return result;
}
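
On the graph-construction side, `ggml_transpose` now tags its result and `ggml_cont_impl` forwards the tag, so a `ggml_cont(ggml_transpose(x))` pair reaches the CUDA backend with the flag set on the destination tensor. A hedged usage sketch (the ggml calls are standard API; the op_params[10] convention is specific to this WIP commit):

#include "ggml.h"

int main(void) {
    struct ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * x  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1024, 256);
    struct ggml_tensor * xt = ggml_transpose(ctx, x);  // view; sets op_params[10] = 999 in this WIP
    struct ggml_tensor * xc = ggml_cont(ctx, xt);      // ggml_cont_impl copies the tag forward

    // xc->op_params[10] == 999 here, which is what the CUDA cpy dispatch checks.
    ggml_free(ctx);
    return 0;
}
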
@@ -4609,8 +4613,18 @@ struct ggml_tensor * ggml_conv_2d_implicitgemm(
struct ggml_tensor *ap, *bp;
if(layout == 0){
ap = ggml_cont(ctx, ggml_permute(ctx, a, 1, 2, 0, 3));
bp = ggml_cont(ctx, ggml_permute(ctx, b, 1, 2, 0, 3));
// ap = ggml_cont(ctx, ggml_permute(ctx, a, 1, 2, 0, 3));
// bp = ggml_cont(ctx, ggml_permute(ctx, b, 1, 2, 0, 3));
ap = ggml_reshape_4d(ctx,
ggml_cont(ctx,
ggml_transpose(ctx,
ggml_reshape_3d(ctx, a, a->ne[0]*a->ne[1], a->ne[2], a->ne[3]))),
a->ne[2], a->ne[0], a->ne[1], a->ne[3]);
bp = ggml_reshape_4d(ctx,
ggml_cont(ctx,
ggml_transpose(ctx,
ggml_reshape_3d(ctx, b, b->ne[0]*b->ne[1], b->ne[2], b->ne[3]))),
b->ne[2], b->ne[0], b->ne[1], b->ne[3]);
} else{
ap = a;
bp = b;
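
Shape walk-through for the new `layout == 0` branch, assuming `a` has ne = [W, H, C, N] (the same reasoning applies to `b`, and to [KW, KH, IC, OC] for the kernel); the labels below are only for this derivation:

//   ggml_reshape_3d(a, W*H, C, N)        -> ne = [W*H, C, N]
//   ggml_transpose(...)                  -> ne = [C, W*H, N]   (view, tagged op_params[10] = 999)
//   ggml_cont(...)                       -> ne = [C, W*H, N]   (materialized via the transpose copy)
//   ggml_reshape_4d(..., C, W, H, N)     -> ne = [C, W, H, N]
//
// which matches the NE and memory layout of the previous
// ggml_cont(ggml_permute(ctx, a, 1, 2, 0, 3)) (channel-most-contiguous, CWHN),
// but routes the data movement through the dedicated 2D transpose kernel.
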


@@ -2414,6 +2414,7 @@ struct test_cpy : public test_case {
const std::array<int64_t, 4> permute_dst;
bool _src_use_permute;
bool _dst_use_permute;
bool is_transpose;
std::string vars() override {
return VARS_TO_STR5(type_src, type_dst, ne, permute_src, permute_dst);
@@ -2430,10 +2431,12 @@
test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 10, 10, 1},
std::array<int64_t, 4> permute_src = {0, 0, 0, 0},
std::array<int64_t, 4> permute_dst = {0, 0, 0, 0})
std::array<int64_t, 4> permute_dst = {0, 0, 0, 0},
bool transpose = false)
: type_src(type_src), type_dst(type_dst), ne(ne), permute_src(permute_src), permute_dst(permute_dst),
_src_use_permute(permute_src[0] + permute_src[1] + permute_src[2] + permute_src[3] > 0),
_dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0) {}
_dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0),
is_transpose(transpose) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
@@ -2454,6 +2457,8 @@ struct test_cpy : public test_case {
}
ggml_tensor * out = ggml_cpy(ctx, src, dst);
if(is_transpose)
dst->op_params[10] = 999;
ggml_set_name(out, "out");
return out;
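
For context, a condensed sketch of what this build_graph produces in the transpose case (hypothetical helper; in the real test the permuted destination comes from permute_dst = {1, 0, 2, 3}):

// Assumes a live ggml_context, as in test-backend-ops. The permuted view makes
// the cpy a transposing copy; the tag forces the CUDA transpose kernel path.
static ggml_tensor * build_transposed_cpy(ggml_context * ctx, int64_t ne0, int64_t ne1) {
    ggml_tensor * src = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne0, ne1);
    ggml_tensor * dst = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne0, ne1);
    dst = ggml_permute(ctx, dst, 1, 0, 2, 3);     // transposed view of the destination
    ggml_tensor * out = ggml_cpy(ctx, src, dst);
    dst->op_params[10] = 999;                     // WIP sentinel read by ggml_cuda_cpy
    return out;
}
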
@@ -4258,14 +4263,14 @@ struct test_conv_2d_implicit : public test_case {
ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());
ggml_set_name(kernel, "kernel");
if (cwhn) {
// change memory layout to channel-most-contiguous (CWHN),
// then permute it back so NE matches the original input
input = ggml_cont(ctx, ggml_permute(ctx, input, 1, 2, 0, 3));
input = ggml_permute(ctx, input, 2, 0, 1, 3);
kernel = ggml_cont(ctx, ggml_permute(ctx, kernel, 2, 3, 1, 0));
kernel = ggml_permute(ctx, kernel, 3, 2, 0, 1);
}
// if (cwhn) {
// // change memory layout to channel-most-contiguous (CWHN),
// // then permute it back so NE matches the original input
// input = ggml_cont(ctx, ggml_permute(ctx, input, 1, 2, 0, 3));
// input = ggml_permute(ctx, input, 2, 0, 1, 3);
// kernel = ggml_cont(ctx, ggml_permute(ctx, kernel, 2, 3, 1, 0));
// kernel = ggml_permute(ctx, kernel, 3, 2, 0, 1);
// }
ggml_tensor * out =
ggml_conv_2d_implicitgemm(ctx, kernel, input, stride0, stride1, padding0, padding1, dilation0, dilation1, cwhn?0:1);
@@ -6831,9 +6836,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1}));
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F16, {512, 3072, 1, 1}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
// test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F16, {512, 3072, 1, 1}));
// test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
// test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {1, 0, 2, 3}, true));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {1, 0, 2, 3}, false));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));


@@ -353,10 +353,10 @@ int main(void)
// std::make_tuple(640,640,52,76,3,3),
// std::make_tuple(640,640,104,152,3,3),
// std::make_tuple(960,320,104,152,3,3),
// std::make_tuple(1280,1280,26,38,3,3),
std::make_tuple(1280,1280,26,38,3,3),
// std::make_tuple(1280,1280,26,38,1,1),
// std::make_tuple(256,128,768,1024,3,3),
std::make_tuple(256,128,768,1024,1,1),
// std::make_tuple(256,128,768,1024,1,1),
// std::make_tuple(1280,640,52,76,3,3),
// std::make_tuple(1920,1280,26,38,3,3),
// std::make_tuple(2560,1280,26,38,3,3),
@@ -451,16 +451,16 @@ int main(void)
// for(int i = 0; i < ggml_nelements(wino_res); i++) {
// for(int i = 0; i < 26*38; i++) {
// // for(int i = 0; i < conv2d_data.size(); i++) {
// for(int i = 0; i < conv2d_data.size(); i++) {
// // float diff = fabs(conv2d_data[i] - wino_data[i]);
// float diff = fabs(im2col_data[i] - wino_data[i]);
// float diff1 = fabs(im2col_data[i] - conv2d_data[i]);
// // if(diff > 1.e-4) {
// if(diff > 0.5) {
// printf("(%7.3f, %7.3f, %7.3f, %.2f, %.2f, %d) \n",
// im2col_data[i], conv2d_data[i],
// wino_data[i], diff, diff1, i);
// // break;
// // }
// }
// }
ggml_free(model.ctx);