diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 3ec0e957ab..3b8d8f87b0 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -2,6 +2,9 @@ #include "dequantize.cuh" #include "convert.cuh" +#define MIN(a, b) (a) < (b) ? (a) : (b) +#define MAX_GRIDDIM_Y 65535 + template static __global__ void k_get_rows( const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, @@ -9,34 +12,36 @@ static __global__ void k_get_rows( /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/ /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, - const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { + const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/, const size_t block_num_y) { - // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. - const int i00 = (blockIdx.y * blockDim.x + threadIdx.x)*2; - const int i10 = blockIdx.x; - const int i11 = blockIdx.z / ne12; - const int i12 = blockIdx.z % ne12; + for (int64_t iy = blockIdx.y; iy < block_num_y; iy+=MAX_GRIDDIM_Y) { + // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. + const int i00 = (iy * blockDim.x + threadIdx.x)*2; + const int i10 = blockIdx.x; + const int i11 = blockIdx.z / ne12; + const int i12 = blockIdx.z % ne12; - if (i00 >= ne00) { - return; + if (i00 >= ne00) { + return; + } + + const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; + + dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; + const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03; + + const int ib = i00/qk; // block index + const int iqs = (i00%qk)/qr; // quant index + const int iybs = i00 - i00%qk; // dst block start index + const int y_offset = qr == 1 ? 1 : qk/2; + + // dequantize + float2 v; + dequantize_kernel(src0_row, ib, iqs, v); + + dst_row[iybs + iqs + 0] = ggml_cuda_cast(v.x); + dst_row[iybs + iqs + y_offset] = ggml_cuda_cast(v.y); } - - const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; - - dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; - const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03; - - const int ib = i00/qk; // block index - const int iqs = (i00%qk)/qr; // quant index - const int iybs = i00 - i00%qk; // dst block start index - const int y_offset = qr == 1 ? 1 : qk/2; - - // dequantize - float2 v; - dequantize_kernel(src0_row, ib, iqs, v); - - dst_row[iybs + iqs + 0] = ggml_cuda_cast(v.x); - dst_row[iybs + iqs + y_offset] = ggml_cuda_cast(v.y); } template @@ -46,24 +51,27 @@ static __global__ void k_get_rows_float( /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/ /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, - const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { + const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/, const size_t block_num_y) { - // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. - const int i00 = blockIdx.y * blockDim.x + threadIdx.x; - const int i10 = blockIdx.x; - const int i11 = blockIdx.z / ne12; - const int i12 = blockIdx.z % ne12; + for (int64_t iy = blockIdx.y; iy < block_num_y; iy+=MAX_GRIDDIM_Y) { - if (i00 >= ne00) { - return; + // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. + const int i00 = iy * blockDim.x + threadIdx.x; + const int i10 = blockIdx.x; + const int i11 = blockIdx.z / ne12; + const int i12 = blockIdx.z % ne12; + + if (i00 >= ne00) { + return; + } + + const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; + + dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; + const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03); + + dst_row[i00] = ggml_cuda_cast(src0_row[i00]); } - - const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; - - dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; - const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03); - - dst_row[i00] = ggml_cuda_cast(src0_row[i00]); } template @@ -98,7 +106,7 @@ static void get_rows_cuda_q( cudaStream_t stream) { const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE); - const dim3 block_nums(ne10, block_num_y, ne11*ne12); + const dim3 block_nums(ne10, MIN(block_num_y, MAX_GRIDDIM_Y), ne11*ne12); // strides in elements // const size_t s0 = nb0 / sizeof(dst_t); @@ -119,7 +127,7 @@ static void get_rows_cuda_q( /*ne10, ne11,*/ ne12, /*ne13,*/ /* s0,*/ s1, s2, s3, /* nb00,*/ nb01, nb02, nb03, - s10, s11, s12/*, s13*/); + s10, s11, s12/*, s13*/, block_num_y); } template @@ -131,7 +139,7 @@ static void get_rows_cuda_float( cudaStream_t stream) { const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE; - const dim3 block_nums(ne10, block_num_y, ne11*ne12); + const dim3 block_nums(ne10, MIN(block_num_y, MAX_GRIDDIM_Y), ne11*ne12); // strides in elements // const size_t s0 = nb0 / sizeof(dst_t); @@ -150,7 +158,7 @@ static void get_rows_cuda_float( /*ne10, ne11,*/ ne12, /*ne13,*/ /* s0,*/ s1, s2, s3, /* nb00,*/ nb01, nb02, nb03, - s10, s11, s12/*, s13*/); + s10, s11, s12/*, s13*/, block_num_y); } template