diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
index ae3f79fd0d..3dd12e177f 100644
--- a/ggml/src/ggml-opencl/ggml-opencl.cpp
+++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
@@ -484,7 +484,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_scale_f32, kernel_scale_f32_4;
     cl_kernel kernel_sqr_cont_f32, kernel_sqr_cont_f32_4, kernel_sqr_cont_f16, kernel_sqr_cont_f16_4;
     cl_kernel kernel_sqrt_cont_f32, kernel_sqrt_cont_f32_4, kernel_sqrt_cont_f16, kernel_sqrt_cont_f16_4;
-    cl_kernel kernel_mean_f32;
+    cl_kernel kernel_mean_f32, kernel_mean_f32_4;
     cl_kernel kernel_silu, kernel_silu_4;
     cl_kernel kernel_gelu, kernel_gelu_4;
     cl_kernel kernel_gelu_erf, kernel_gelu_erf_4;
@@ -543,7 +543,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_solve_tri_f32;
     cl_kernel kernel_im2col_f32, kernel_im2col_f16;
     cl_kernel kernel_argsort_f32_i32;
-    cl_kernel kernel_sum_rows_f32;
+    cl_kernel kernel_sum_rows_f32, kernel_sum_rows_f32_4;
     cl_kernel kernel_repeat_f32;
     cl_kernel kernel_pad;
     cl_kernel kernel_tanh_f32, kernel_tanh_f32_4, kernel_tanh_f32_nc;
@@ -1837,6 +1837,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
         CL_CHECK((backend_ctx->kernel_mean_f32 = clCreateKernel(prog, "kernel_mean_f32", &err), err));
+        CL_CHECK((backend_ctx->kernel_mean_f32_4 = clCreateKernel(prog, "kernel_mean_f32_4", &err), err));
         CL_CHECK(clReleaseProgram(prog));
         GGML_LOG_CONT(".");
     }
@@ -1874,6 +1875,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
         CL_CHECK((backend_ctx->kernel_sum_rows_f32 = clCreateKernel(backend_ctx->program_sum_rows_f32, "kernel_sum_rows_f32", &err), err));
+        CL_CHECK((backend_ctx->kernel_sum_rows_f32_4 = clCreateKernel(backend_ctx->program_sum_rows_f32, "kernel_sum_rows_f32_4", &err), err));
         GGML_LOG_CONT(".");
     }
 
@@ -3587,7 +3589,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
             }
        case GGML_OP_SUM_ROWS:
        case GGML_OP_MEAN:
-            return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
+            return op->src[0]->type == GGML_TYPE_F32;
        case GGML_OP_FLASH_ATTN_EXT:
            {
                const ggml_tensor * q = op->src[0];
@@ -6400,7 +6402,6 @@ static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const
     GGML_UNUSED(src1);
 
     GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
-    GGML_ASSERT(ggml_is_contiguous(src0));
 
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
 
@@ -6423,7 +6424,14 @@ static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const
     const cl_ulong nb2 = dst->nb[2];
     const cl_ulong nb3 = dst->nb[3];
 
-    cl_kernel kernel = backend_ctx->kernel_mean_f32;
+    cl_kernel kernel;
+
+    const bool is_c4 = ne00 % 4 == 0;
+    if (is_c4) {
+        kernel = backend_ctx->kernel_mean_f32_4;
+    } else {
+        kernel = backend_ctx->kernel_mean_f32;
+    }
 
     CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
     CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
@@ -6440,7 +6448,7 @@ static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const
     CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2));
     CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3));
 
-    size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03};
+    size_t global_work_size[] = {64 * (size_t)ne01, (size_t)ne02, (size_t)ne03};
     size_t local_work_size[] = {(size_t)64, 1, 1};
 
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
@@ -11088,7 +11096,6 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c
     GGML_UNUSED(src1);
 
     GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
-    GGML_ASSERT(ggml_is_contiguous(src0));
 
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
 
@@ -11111,7 +11118,14 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c
     const cl_ulong nb2 = dst->nb[2];
     const cl_ulong nb3 = dst->nb[3];
 
-    cl_kernel kernel = backend_ctx->kernel_sum_rows_f32;
+    cl_kernel kernel;
+
+    const bool is_c4 = ne00 % 4 == 0;
+    if (is_c4) {
+        kernel = backend_ctx->kernel_sum_rows_f32_4;
+    } else {
+        kernel = backend_ctx->kernel_sum_rows_f32;
+    }
 
     CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
     CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
@@ -11128,7 +11142,7 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c
     CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2));
     CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3));
 
-    size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03};
+    size_t global_work_size[] = {64 * (size_t)ne01, (size_t)ne02, (size_t)ne03};
     size_t local_work_size[] = {(size_t)64, 1, 1};
 
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
diff --git a/ggml/src/ggml-opencl/kernels/mean.cl b/ggml/src/ggml-opencl/kernels/mean.cl
index 5c3e8bcd86..ab7dac4678 100644
--- a/ggml/src/ggml-opencl/kernels/mean.cl
+++ b/ggml/src/ggml-opencl/kernels/mean.cl
@@ -1,8 +1,11 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+#define MAX_SUBGROUPS 64
 
 kernel void kernel_mean_f32(
-        global float * src0,
+        global char * src0,
         ulong offset0,
-        global float * dst,
+        global char * dst,
         ulong offsetd,
         int ne00,
         int ne01,
@@ -15,25 +18,121 @@ kernel void kernel_mean_f32(
         ulong nb2,
         ulong nb3
 ) {
-    src0 = (global float *)((global char *)src0 + offset0);
-    dst = (global float *)((global char *)dst + offsetd);
+    src0 = src0 + offset0;
+    dst = dst + offsetd;
 
-    int i3 = get_global_id(2);
-    int i2 = get_global_id(1);
-    int i1 = get_global_id(0);
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
+
+    const int lid = get_local_id(0);
+    const int lsize = get_local_size(0);
+
+    const uint sg_size = get_sub_group_size();
+    const uint sg_id = get_sub_group_id();
+    const uint sg_lid = get_sub_group_local_id();
+
+    __local float lmem[MAX_SUBGROUPS];
 
     if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
         return;
     }
 
-    global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
-    global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
-
-    float row_sum = 0;
-
-    for (int i0 = 0; i0 < ne00; i0++) {
-        row_sum += src_row[i0];
+    if(sg_id == 0){
+        lmem[sg_lid] = 0.0f;
     }
 
-    dst_row[0] = row_sum / ne00;
+    global float * src_row = (global float *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
+    global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
+
+    float sumf = 0.0f;
+
+    for (int i0 = lid; i0 < ne00; i0 += lsize) {
+        sumf += src_row[i0];
+    }
+
+    sumf = sub_group_reduce_add(sumf);
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    if(sg_lid == 0){
+        lmem[sg_id] = sumf;
+    }
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    sumf = lmem[sg_lid];
+    sumf = sub_group_reduce_add(sumf);
+
+    if (lid == 0) {
+        dst_row[0] = sumf / ne00;
+    }
+}
+
+kernel void kernel_mean_f32_4(
+        global char * src0,
+        ulong offset0,
+        global char * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+    src0 = src0 + offset0;
+    dst = dst + offsetd;
+
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
+
+    const int lid = get_local_id(0);
+    const int lsize = get_local_size(0);
+
+    const uint sg_size = get_sub_group_size();
+    const uint sg_id = get_sub_group_id();
+    const uint sg_lid = get_sub_group_local_id();
+
+    __local float lmem[MAX_SUBGROUPS];
+
+    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+        return;
+    }
+
+    if(sg_id == 0){
+        lmem[sg_lid] = 0.0f;
+    }
+
+    global float4 * src_row = (global float4 *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
+    global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
+
+    float4 sum_vec = (float4)0.0f;
+
+    for (int i0 = lid; i0 < ne00 / 4; i0 += lsize) {
+        sum_vec += src_row[i0];
+    }
+
+    float sumf = dot(sum_vec, (float4)(1.0f));
+    sumf = sub_group_reduce_add(sumf);
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    if(sg_lid == 0){
+        lmem[sg_id] = sumf;
+    }
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    sumf = lmem[sg_lid];
+    sumf = sub_group_reduce_add(sumf);
+
+    if (lid == 0) {
+        dst_row[0] = sumf / ne00;
+    }
 }
diff --git a/ggml/src/ggml-opencl/kernels/sum_rows.cl b/ggml/src/ggml-opencl/kernels/sum_rows.cl
index c5f7c570f9..ed11ddab8e 100644
--- a/ggml/src/ggml-opencl/kernels/sum_rows.cl
+++ b/ggml/src/ggml-opencl/kernels/sum_rows.cl
@@ -1,8 +1,11 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+#define MAX_SUBGROUPS 64
 
 kernel void kernel_sum_rows_f32(
-        global float * src0,
+        global char * src0,
         ulong offset0,
-        global float * dst,
+        global char * dst,
         ulong offsetd,
         int ne00,
         int ne01,
@@ -15,25 +18,121 @@ kernel void kernel_sum_rows_f32(
         ulong nb2,
         ulong nb3
 ) {
-    src0 = (global float *)((global char *)src0 + offset0);
-    dst = (global float *)((global char *)dst + offsetd);
+    src0 = src0 + offset0;
+    dst = dst + offsetd;
 
-    int i3 = get_global_id(2);
-    int i2 = get_global_id(1);
-    int i1 = get_global_id(0);
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
+
+    const int lid = get_local_id(0);
+    const int lsize = get_local_size(0);
+
+    const uint sg_size = get_sub_group_size();
+    const uint sg_id = get_sub_group_id();
+    const uint sg_lid = get_sub_group_local_id();
+
+    __local float lmem[MAX_SUBGROUPS];
 
     if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
         return;
     }
 
-    global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
-    global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
-
-    float row_sum = 0;
-
-    for (int i0 = 0; i0 < ne00; i0++) {
-        row_sum += src_row[i0];
+    if(sg_id == 0){
+        lmem[sg_lid] = 0.0f;
     }
 
-    dst_row[0] = row_sum;
+    global float * src_row = (global float *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
+    global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
+
+    float sumf = 0.0f;
+
+    for (int i0 = lid; i0 < ne00; i0 += lsize) {
+        sumf += src_row[i0];
+    }
+
+    sumf = sub_group_reduce_add(sumf);
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    if(sg_lid == 0){
+        lmem[sg_id] = sumf;
+    }
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    sumf = lmem[sg_lid];
+    sumf = sub_group_reduce_add(sumf);
+
+    if (lid == 0) {
+        dst_row[0] = sumf;
+    }
+}
+
+kernel void kernel_sum_rows_f32_4(
+        global char * src0,
+        ulong offset0,
+        global char * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+    src0 = src0 + offset0;
+    dst = dst + offsetd;
+
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
+
+    const int lid = get_local_id(0);
+    const int lsize = get_local_size(0);
+
+    const uint sg_size = get_sub_group_size();
+    const uint sg_id = get_sub_group_id();
+    const uint sg_lid = get_sub_group_local_id();
+
+    __local float lmem[MAX_SUBGROUPS];
+
+    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+        return;
+    }
+
+    if(sg_id == 0){
+        lmem[sg_lid] = 0.0f;
+    }
+
+    global float4 * src_row = (global float4 *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
+    global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
+
+    float4 sum_vec = (float4)0.0f;
+
+    for (int i0 = lid; i0 < ne00 / 4; i0 += lsize) {
+        sum_vec += src_row[i0];
+    }
+
+    float sumf = dot(sum_vec, (float4)(1.0f));
+    sumf = sub_group_reduce_add(sumf);
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    if(sg_lid == 0){
+        lmem[sg_id] = sumf;
+    }
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    sumf = lmem[sg_lid];
+    sumf = sub_group_reduce_add(sumf);
+
+    if (lid == 0) {
+        dst_row[0] = sumf;
+    }
 }