diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 0b9a021d20..a4403a5c27 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -416,7 +416,6 @@ struct ggml_backend_opencl_context { cl_program program_add; cl_program program_add_id; cl_program program_clamp; - cl_program program_cpy; cl_program program_cvt; cl_program program_diag_mask_inf; cl_program program_gelu; @@ -514,7 +513,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_set_rows_f32_i64, kernel_set_rows_f32_i32, kernel_set_rows_f16_i64, kernel_set_rows_f16_i32; cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16; cl_kernel kernel_rope_multi_f32, kernel_rope_multi_f16, kernel_rope_vision_f32, kernel_rope_vision_f16; - cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32; + cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32, kernel_cpy_i32_i32; cl_kernel kernel_mul_mat_f32_f32; cl_kernel kernel_mul_mat_f16_f16; cl_kernel kernel_mul_mat_f16_f32_1row; @@ -873,13 +872,14 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve #else const std::string kernel_src = read_file("cpy.cl"); #endif - backend_ctx->program_cpy = + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_cpy_f16_f16 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f16_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f16_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f32_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_cpy_f32_f32 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f32_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f16_f16 = clCreateKernel(prog, "kernel_cpy_f16_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(prog, "kernel_cpy_f16_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(prog, "kernel_cpy_f32_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f32_f32 = clCreateKernel(prog, "kernel_cpy_f32_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_i32_i32 = clCreateKernel(prog, "kernel_cpy_i32_i32", &err), err)); GGML_LOG_CONT("."); } @@ -3544,9 +3544,21 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te default: return false; } + case GGML_TYPE_I32: + switch (op->type) { + case GGML_TYPE_I32: + return true; + default: + return false; + } default: return false; } + case GGML_OP_SET: { + return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32) && + op->type == op->src[0]->type && + op->type == op->src[1]->type; + } case GGML_OP_SCALE: return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]); case GGML_OP_ADD: @@ -10782,28 +10794,13 @@ static void ggml_cl_cpy(ggml_backend_t backend, const ggml_tensor * src0, const // GGML_OP_DUP and GGML_OP_CONT happen between src0 and dst. UNUSED(dst); - const int ne00 = src0 ? src0->ne[0] : 0; - const int ne01 = src0 ? src0->ne[1] : 0; - const int ne02 = src0 ? src0->ne[2] : 0; - const int ne03 = src0 ? src0->ne[3] : 0; + GGML_TENSOR_LOCALS(int, ne0, src0, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb); + GGML_TENSOR_LOCALS(int, ne1, src1, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb1, src1, nb); - const cl_ulong nb00 = src0 ? src0->nb[0] : 0; - const cl_ulong nb01 = src0 ? src0->nb[1] : 0; - const cl_ulong nb02 = src0 ? src0->nb[2] : 0; - const cl_ulong nb03 = src0 ? src0->nb[3] : 0; - - const int ne10 = src1 ? src1->ne[0] : 0; - const int ne11 = src1 ? src1->ne[1] : 0; - const int ne12 = src1 ? src1->ne[2] : 0; - const int ne13 = src1 ? src1->ne[3] : 0; - - const cl_ulong nb10 = src1 ? src1->nb[0] : 0; - const cl_ulong nb11 = src1 ? src1->nb[1] : 0; - const cl_ulong nb12 = src1 ? src1->nb[2] : 0; - const cl_ulong nb13 = src1 ? src1->nb[3] : 0; - - const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; - const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; + const enum ggml_type src0t = src0->type; + const enum ggml_type src1t = src1->type; ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -10840,6 +10837,15 @@ static void ggml_cl_cpy(ggml_backend_t backend, const ggml_tensor * src0, const GGML_ASSERT(false && "not implemented"); } break; + case GGML_TYPE_I32: + switch (src1t) { + case GGML_TYPE_I32: + kernel = backend_ctx->kernel_cpy_i32_i32; + break; + default: + GGML_ASSERT(false && "not implemented"); + } + break; default: GGML_ASSERT(false && "not implemented"); } @@ -10878,6 +10884,89 @@ static void ggml_cl_dup(ggml_backend_t backend, const ggml_tensor * src0, const UNUSED(src1); } +static void ggml_cl_set(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + GGML_ASSERT((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32) && + src1->type == src0->type && dst->type == src0->type); + + GGML_TENSOR_LOCALS(int, ne0, src0, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb); + GGML_TENSOR_LOCALS(int, ne1, src1, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb1, src1, nb); + GGML_TENSOR_LOCALS(int, ne, dst, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb, dst, nb); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const cl_ulong pnb1 = ((const int32_t *)dst->op_params)[0]; + const cl_ulong pnb2 = ((const int32_t *)dst->op_params)[1]; + const cl_ulong pnb3 = ((const int32_t *)dst->op_params)[2]; + const cl_ulong offs = ((const int32_t *)dst->op_params)[3]; + const bool inplace = (bool)((const int32_t *)dst->op_params)[4]; + + cl_kernel kernel = nullptr; + + // for inplace case, dst is a view of src0 and is updated on top of it + // so for non-inplace case, copy src0 to dst first + if (!inplace) { + ggml_cl_cpy(backend, src0, dst, nullptr); + } + + // then copy src1 to dst with specified offset + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_cpy_f32_f32; + } else if (src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { + kernel = backend_ctx->kernel_cpy_i32_i32; + } else { + GGML_ASSERT(false && "not implemented"); + } + + offsetd += offs; + cl_ulong nb = ggml_element_size(dst); + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &pnb1)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &pnb2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &pnb3)); + + int max_local_size = backend_ctx->get_kernel_workgroup_size(kernel); + + const int nth = MIN(max_local_size, ne00); + + size_t global_work_size[] = {(size_t)ne11*nth, (size_t)ne12, (size_t)ne13}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); +} + static void ggml_cl_diag_mask_inf(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -11651,6 +11740,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_cpy; break; + case GGML_OP_SET: + if (!any_on_device) { + return false; + } + func = ggml_cl_set; + break; case GGML_OP_DUP: case GGML_OP_CONT: if (!any_on_device) { diff --git a/ggml/src/ggml-opencl/kernels/cpy.cl b/ggml/src/ggml-opencl/kernels/cpy.cl index 9369351a60..820aa538a3 100644 --- a/ggml/src/ggml-opencl/kernels/cpy.cl +++ b/ggml/src/ggml-opencl/kernels/cpy.cl @@ -182,3 +182,48 @@ kernel void kernel_cpy_f32_f32( dst_data[i00] = src[0]; } } + +kernel void kernel_cpy_i32_i32( + global int * src0, + ulong offset0, + global int * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global int*)((global char*)src0 + offset0); + dst = (global int*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global int * dst_data = (global int *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + global const int * src = (global int *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + dst_data[i00] = src[0]; + } +}