WIP: fixed a bug in cpy transpose index computation
parent a3784e17ad
commit 6d12288037
@@ -52,7 +52,7 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst_direct, co
     const int64_t nmat = ne /(ne00 * ne01);
     const int64_t n = ne00 * ne01;
     // const int64_t n = ne01 * ne02;
-    int width = gridDim.x * TILE_DIM;
+    int width = ne01;
     int x = blockIdx.x * TILE_DIM + threadIdx.x;
     int y = blockIdx.y * TILE_DIM + threadIdx.y;
     int tx = blockIdx.y * TILE_DIM + threadIdx.x; // transpose block offset
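For reference, this kernel follows the classic tiled shared-memory transpose pattern. Below is a minimal self-contained sketch of that pattern (the toy kernel, its float-only signature, and the row-major layout are illustrative assumptions; only TILE_DIM, BLOCK_ROWS, and the x/y/tx index convention are taken from the diff). It illustrates why the write-side row length has to come from the tensor extent, as in the new width = ne01, rather than from gridDim.x * TILE_DIM, which is rounded up to a tile multiple and over-counts whenever the extent is not divisible by TILE_DIM:

#include <cuda_runtime.h>

#define TILE_DIM   32
#define BLOCK_ROWS 8

// Toy row-major transpose: dst[c][r] = src[r][c] for a rows x cols matrix.
// Launch: dim3 grid((cols + TILE_DIM - 1) / TILE_DIM,
//                   (rows + TILE_DIM - 1) / TILE_DIM), block(TILE_DIM, BLOCK_ROWS, 1);
__global__ void transpose_tiled_sketch(const float * src, float * dst, int rows, int cols) {
    __shared__ float tile[TILE_DIM][TILE_DIM + 1];   // +1 column avoids bank conflicts

    int x = blockIdx.x * TILE_DIM + threadIdx.x;     // source column
    int y = blockIdx.y * TILE_DIM + threadIdx.y;     // source row
    for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
        if (x < cols && y + j < rows) {
            tile[threadIdx.y + j][threadIdx.x] = src[(y + j) * cols + x];
        }
    }
    __syncthreads();

    int tx = blockIdx.y * TILE_DIM + threadIdx.x;    // destination column (= source row)
    int ty = blockIdx.x * TILE_DIM + threadIdx.y;    // destination row (= source column)
    for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
        // The destination pitch is the true extent (rows), not a grid-derived
        // value: for rows = 100 the grid rounds up to 4 blocks = 128 lanes, and
        // using 128 as the pitch would shear every row of the output.
        if (tx < rows && ty + j < cols) {
            dst[(ty + j) * rows + tx] = tile[threadIdx.x][threadIdx.y + j];
        }
    }
}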
@@ -194,8 +194,8 @@ static void ggml_cpy_flt_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    if constexpr ((std::is_same_v<src_t, half> && std::is_same_v<dst_t, half> ||
+    if constexpr ((std::is_same_v<src_t, half> && std::is_same_v<dst_t, half> ||
                    std::is_same_v<src_t, float> && std::is_same_v<dst_t, float>)
                   && transpose){
+        // printf("cuda cpy transpose ne=%d ne00=%d ne01=%d ne10=%d ne11=%d\n", ne, ne00, ne01, ne10, ne11);
@@ -203,12 +203,13 @@ static void ggml_cpy_flt_cuda(
         // if (ne00 == ne11 && ne01 == ne10 && nb00 == nb11 && nb10 == nb01){ //transpose
         // if (transpose) { //transpose
         // printf("cuda cpy transpose ne=%d ne00=%d ne01=%d ne10=%d ne11=%d\n", ne, ne00, ne01, ne10, ne11);
-        dim3 dimGrid( (ne00 + TILE_DIM - 1) / TILE_DIM,
-                      (ne01 + TILE_DIM - 1) / TILE_DIM,
+        dim3 dimGrid( (ne01 + TILE_DIM - 1) / TILE_DIM,
+                      (ne00 + TILE_DIM - 1) / TILE_DIM,
                       (ne/(ne00*ne01) + BLOCK_NM - 1) / BLOCK_NM );
         dim3 dimBlock(TILE_DIM, BLOCK_ROWS, 1);
         cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
     } else{ // other
+        const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
         cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
             (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
     }
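This hunk is the core of the fix: grid.x now tiles ne01 and grid.y tiles ne00, so the launch shape matches the kernel's x/y convention and its new width = ne01. A worked example of the ceil-division arithmetic (host-only sketch; the concrete extents and the BLOCK_NM value are illustrative assumptions, not values from the repo):

#include <cstdio>

int main() {
    const int TILE_DIM = 32, BLOCK_NM = 8;   // BLOCK_NM value assumed
    const int ne00 = 64, ne01 = 100;         // 64 elements per row, 100 rows
    const int ne   = ne00 * ne01 * 3;        // batch of 3 matrices

    // Old geometry: x tiled ne00, y tiled ne01 -> grid (2, 4).
    // Fixed geometry: x tiles ne01, y tiles ne00 -> grid (4, 2).
    const int gx = (ne01 + TILE_DIM - 1) / TILE_DIM;               // ceil(100/32) = 4
    const int gy = (ne00 + TILE_DIM - 1) / TILE_DIM;               // ceil(64/32)  = 2
    const int gz = (ne / (ne00 * ne01) + BLOCK_NM - 1) / BLOCK_NM; // ceil(3/8)    = 1

    // With the old width = gridDim.x * TILE_DIM, a 100-wide extent becomes
    // 4 * 32 = 128, which is why the kernel now uses width = ne01 directly.
    printf("grid = (%d, %d, %d)\n", gx, gy, gz);
    return 0;
}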
@@ -401,7 +402,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
             CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        if(src1->op_params[10] == 999){
+        if(src0->op_params[10] == 999){
             ggml_cpy_flt_cuda<float, float, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
         } else {
             ggml_cpy_flt_cuda<float, float, false> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
@@ -436,7 +437,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        if(src1->op_params[10] == 999){
+        if(src0->op_params[10] == 999){
             ggml_cpy_flt_cuda<half, half, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
         } else {
             ggml_cpy_flt_cuda<half, half, false> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
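In both of these dispatch sites the third template argument is a compile-time transpose flag: ggml_cpy_flt_cuda<float, float, true> routes to the tile kernel, false to the generic element-wise copy, via the if constexpr shown in the earlier hunk. A minimal sketch of that dispatch pattern (toy function body and names; only the type condition is copied from the diff):

#include <cstdio>
#include <type_traits>
#include <cuda_fp16.h>

// Toy stand-in for ggml_cpy_flt_cuda's compile-time branch: the tiled
// transpose path exists only for same-type half->half and float->float copies.
template <typename src_t, typename dst_t, bool transpose>
static void toy_cpy_dispatch() {
    if constexpr ((std::is_same_v<src_t, half> && std::is_same_v<dst_t, half> ||
                   std::is_same_v<src_t, float> && std::is_same_v<dst_t, float>)
                  && transpose) {
        printf("transpose tile kernel\n");
    } else {
        printf("generic element-wise copy kernel\n");
    }
}

int main() {
    toy_cpy_dispatch<float, float, true>();   // tile kernel
    toy_cpy_dispatch<half,  half,  true>();   // tile kernel
    toy_cpy_dispatch<float, float, false>();  // generic copy
    toy_cpy_dispatch<float, half,  true>();   // mixed types fall through to generic
    return 0;
}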
@@ -3301,9 +3301,6 @@ static struct ggml_tensor * ggml_cont_impl(
     result->op = GGML_OP_CONT;
     result->src[0] = a;
-    if (a->op == GGML_OP_TRANSPOSE) {
-        result->op_params[10] = a->op_params[10]; // preserve the original order
-    }

     return result;
 }
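Together with the two dispatch hunks above, this removal moves the transpose marker from the destination to the source: ggml_cont_impl no longer forwards op_params[10] onto the CONT result, so ggml_cuda_cpy now reads the 999 sentinel from src0, the transposed view itself. A stand-in sketch of the before/after flag flow (the toy struct and the setter location are assumptions; only the op_params[10] slot and the 999 value come from the diff):

#include <cstdint>
#include <cstdio>

struct toy_tensor {                 // stand-in, not the real ggml_tensor
    int32_t op_params[16];
};

int main() {
    toy_tensor view = {};           // result of a transpose
    toy_tensor cont = {};           // result of ggml_cont on that view
    view.op_params[10] = 999;       // sentinel set where the view is created

    // Before: ggml_cont_impl copied the flag onto cont, and the CUDA
    // dispatch tested src1 (the cont result).
    // After: the flag stays on the view, and the dispatch tests src0.
    const toy_tensor * src0 = &view;
    (void)cont;
    printf(src0->op_params[10] == 999 ? "transpose path\n" : "generic path\n");
    return 0;
}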