diff --git a/ggml/src/ggml-cuda/cpy-utils.cuh b/ggml/src/ggml-cuda/cpy-utils.cuh index e621cb9811..7697c292dd 100644 --- a/ggml/src/ggml-cuda/cpy-utils.cuh +++ b/ggml/src/ggml-cuda/cpy-utils.cuh @@ -212,6 +212,6 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) { } template -static __device__ void cpy_1_flt(const char * cxi, char * cdsti) { +static __device__ void cpy_1_scalar(const char * cxi, char * cdsti) { *(dst_t *) cdsti = ggml_cuda_cast(*(const src_t *) cxi); } diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index c1afde9627..82367ad3fb 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -12,10 +12,10 @@ const int CUDA_CPY_BLOCK_NM = 8; // block size of 3rd dimension if available const int CUDA_CPY_BLOCK_ROWS = 8; // block dimension for marching through rows template -static __global__ void cpy_flt(const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13) { +static __global__ void cpy_scalar(const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13) { const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= ne) { @@ -40,7 +40,7 @@ static __global__ void cpy_flt(const char * cx, char * cdst, const int ne, } template -static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne, +static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13) { @@ -166,7 +166,7 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne, } template -static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const int64_t ne) { +static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const int64_t ne) { const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= ne) { @@ -180,17 +180,17 @@ static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const in } template -static void ggml_cpy_flt_contiguous_cuda( +static void ggml_cpy_scalar_contiguous_cuda( const char * cx, char * cdst, const int64_t ne, cudaStream_t stream) { const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - cpy_flt_contiguous<<>> + cpy_scalar_contiguous<<>> (cx, cdst, ne); } template -static void ggml_cpy_flt_cuda( +static void ggml_cpy_scalar_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { @@ -212,11 +212,11 @@ static void ggml_cpy_flt_cuda( (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D, (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM); dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1); - cpy_flt_transpose<<>> + cpy_scalar_transpose<<>> (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } else { const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - cpy_flt><<>> + cpy_scalar><<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } } @@ -399,94 +399,132 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { if (can_be_transposed) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { if (contiguous_srcs) { - ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + ggml_cpy_scalar_contiguous_cuda + (src0_ddc, src1_ddc, ne, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { if (contiguous_srcs) { - ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + ggml_cpy_scalar_contiguous_cuda + (src0_ddc, src1_ddc, ne, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { - ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q8_0_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) { - ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_q8_0_f32_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { - ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q4_0_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) { - ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, - nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_q4_0_f32_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { - ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q4_1_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) { - ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, - nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_q4_1_f32_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { - ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q5_0_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) { - ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, - nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_q5_0_f32_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { - ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_iq4_nl_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { - ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q5_1_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { - ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_q5_1_f32_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { if (can_be_transposed) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) { if (contiguous_srcs) { - ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + ggml_cpy_scalar_contiguous_cuda + (src0_ddc, src1_ddc, ne, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { if (contiguous_srcs) { - ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + ggml_cpy_scalar_contiguous_cuda + (src0_ddc, src1_ddc, ne, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { if (can_be_transposed) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) { if (contiguous_srcs) { - ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + ggml_cpy_scalar_contiguous_cuda + (src0_ddc, src1_ddc, ne, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) { if (contiguous_srcs) { - ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + ggml_cpy_scalar_contiguous_cuda + (src0_ddc, src1_ddc, ne, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } + } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) { + if (can_be_transposed) { + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else { + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) { if (contiguous_srcs) { - ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + ggml_cpy_scalar_contiguous_cuda + (src0_ddc, src1_ddc, ne, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) { if (contiguous_srcs) { - ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + ggml_cpy_scalar_contiguous_cuda + (src0_ddc, src1_ddc, ne, main_stream); } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_scalar_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else { GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 8fe0899bb5..0b29074f33 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -4115,6 +4115,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) { return true; } + if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_I32) { + return true; + } if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) { return true; } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 90bdb80e5e..ce8c068d7a 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6953,9 +6953,11 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_I32, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_I32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0})); test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0})); - for (ggml_type type_dst : { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16 }) { + for (ggml_type type_dst : { GGML_TYPE_F32, GGML_TYPE_I32, GGML_TYPE_F16, GGML_TYPE_BF16 }) { for (bool use_view_slice : { true, false }) { for (std::array ne : std::initializer_list>{ {2, 1, 1, 1}, {2, 1, 3, 5}, {2, 3, 5, 7}, {1, 4, 4, 1}, {1, 8, 17, 1}, {10, 10, 10, 1} }) {