diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu
index d204807a2f..5bb5cd7cad 100644
--- a/ggml/src/ggml-cuda/conv2d-implicit.cu
+++ b/ggml/src/ggml-cuda/conv2d-implicit.cu
@@ -811,7 +811,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
     constexpr unsigned int TILE_COLS_VECTORIZED = BK / 8;
     constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
     constexpr unsigned int A_K_STRID = BM / ROW_STEP;
-    constexpr unsigned int B_K_STRID = BN / ROW_STEP;
+    // constexpr unsigned int B_K_STRID = BN / ROW_STEP;
 
     unsigned int masks_a[A_K_STRID][2];
     int64_t element_offset_a[A_K_STRID];
@@ -1263,9 +1263,9 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
     // if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) {
     if (BlocksM * BlocksN < 2*(unsigned int)nsm){
         int j, max_remaining_waves = -1, candidate = -1;
-        int ks = min(16, nsm / (BlocksM * BlocksN));
+        int ks = min(20, nsm / (BlocksM * BlocksN));
         if (ks < 2 && (BlocksM * BlocksN) % nsm < nsm*4/5)
-            ks = 16;
+            ks = 20;
         for (j = 2; j <= ks; j++){
             const int remainder = (BlocksM * BlocksN * j) % nsm;
             // if ((P.c * P.r * P.s) % (8*j) == 0){
@@ -1328,7 +1328,20 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
             } else if (j == 16) {
                 launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+            } else if (j == 17) {
+                launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+            } else if (j == 18) {
+                launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+            } else if (j == 19) {
+                launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+            } else if (j == 20) {
+                launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
             }
+            return;
         }
     }
@@ -1395,7 +1408,7 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor *
     const uint B = input->ne[3]; // n_batches
 
-    int64_t pp[3];
+    int64_t pp[3] = {0};
     // const unsigned int K = param.c;
     // const uint inChannelOffset = param.c * param.w;
     // const uint weightKOffset = param.c * param.r * param.s;
diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh
index 9f817a0078..924b678b81 100644
--- a/ggml/src/ggml-cuda/conv2d-implicit.cuh
+++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh
@@ -37,7 +37,8 @@ typedef struct{
 
 /// Clears the predicates
 template
-__host__ __device__ void clear_mask(unsigned int masks_[][2], bool clear = true) {
+// __host__ __device__ void clear_mask(unsigned int masks_[][2], bool clear = true) {
+__device__ void clear_mask(unsigned int masks_[][2], bool clear = true) {
     #pragma unroll
     for (int s = 0; s < K_STRID; ++s) {
@@ -47,7 +48,8 @@ __host__ __device__ void clear_mask(unsigned int masks_[][2], bool clear = true)
 }
 
 template
-__host__ __device__ void add_byte_offset(int64_t element_offset[], const int64_t offset){
+// __host__ __device__ void add_byte_offset(int64_t element_offset[], const int64_t offset){
+__device__ void add_byte_offset(int64_t element_offset[], const int64_t offset){
     #pragma unroll
     for (int s = 0; s < K_STRID; ++s) {
         element_offset[s] += offset;
     }
@@ -66,7 +68,7 @@ __device__ void prepareIteratorA(unsigned int thread_row,
 
     int offset_p[A_K_STRID];
     int offset_q[A_K_STRID];
-    constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
+    // constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
     // unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
 
     // const unsigned int chw = param.c * param.h * param.w;
@@ -436,15 +438,22 @@ __device__ __forceinline__ void tileMemcpyLoadB(
     const unsigned int curC = start_k+block_k+thread_col*8;
     const unsigned int ki = (curR*param.s+curS)*param.c + curC;
 
+    unsigned int iter_idx = thread_row * param.weightKOffset + ki;
+    unsigned int krow_idx = thread_row + blockIdx.x * TILE_ROWS;
+    const int ITER_STEPS = ROW_STEP * param.weightKOffset;
+
     #pragma unroll
     for (unsigned int i = 0; i < NUM_ITERS; i++){
-        const unsigned int src_index = thread_row * param.weightKOffset + ki;
-        if (thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k){
+        // const unsigned int src_index = thread_row * param.weightKOffset + ki;
+        const unsigned int src_index = iter_idx;
+        // if (thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k){
+        if (krow_idx < param.k && curC < end_k){
             dst_reg[i] = reinterpret_cast(&src[src_index])[0];
         }else{ // read 4 halves
             dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
         }
-        thread_row += ROW_STEP;
+        krow_idx += ROW_STEP;
+        iter_idx += ITER_STEPS;
     }
 #else
     GGML_UNUSED(src);
@@ -492,21 +501,24 @@ __device__ __forceinline__ void tileMemcpySwizzleStore(
     // to cover the whole tile
     constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
     constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
+    constexpr unsigned int ITER_STEPS = ROW_STEP * TILE_COLS_VECTORIZED;
     // unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
     // const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
 
     // compile time check that we provided the right amount of registers for storage
     static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
 
+    unsigned int iter_idx = thread_row * TILE_COLS_VECTORIZED + thread_col;
     #pragma unroll
     for (unsigned int i = 0; i < NUM_ITERS; i++) {
         // apply swizzle to the dst index
-        unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
+        // unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
+        unsigned int dst_index = iter_idx;
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
         dst_float4[dst_index] = src_reg[i];
-        thread_row += ROW_STEP;
+        iter_idx += ITER_STEPS;
     }
 #else
     GGML_UNUSED(src_reg);
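
Note (illustration, not part of the patch): the tileMemcpyLoadB and tileMemcpySwizzleStore hunks above hoist the per-iteration index computation (thread_row * stride + column) out of the unrolled copy loops and instead advance a running index by a constant step each iteration. The standalone sketch below checks that this produces the same indices; the names ROW_STEP, STRIDE, NUM_ITERS and their values are placeholders for the example and are not taken from the kernel.

    // Illustrative sketch only: a hoisted index advanced by a constant step
    // matches recomputing row * STRIDE + col on every iteration.
    #include <cstdio>

    constexpr unsigned int ROW_STEP  = 4;   // rows advanced per iteration (placeholder)
    constexpr unsigned int STRIDE    = 32;  // elements per row (placeholder)
    constexpr unsigned int NUM_ITERS = 8;   // unrolled iterations per thread (placeholder)

    int main() {
        const unsigned int thread_row = 1, thread_col = 3;

        // Old pattern: recompute the index from the row counter every iteration.
        unsigned int ref[NUM_ITERS];
        unsigned int row = thread_row;
        for (unsigned int i = 0; i < NUM_ITERS; ++i) {
            ref[i] = row * STRIDE + thread_col;
            row += ROW_STEP;
        }

        // New pattern: compute the index once, then add a constant ITER_STEPS.
        constexpr unsigned int ITER_STEPS = ROW_STEP * STRIDE;
        unsigned int idx = thread_row * STRIDE + thread_col;
        for (unsigned int i = 0; i < NUM_ITERS; ++i) {
            if (idx != ref[i]) {
                printf("mismatch at iteration %u\n", i);
                return 1;
            }
            idx += ITER_STEPS;
        }
        printf("indices match\n");
        return 0;
    }

The patch applies this transformation twice: to the global-memory source index in tileMemcpyLoadB (together with a hoisted krow_idx for the bounds check) and to the shared-memory destination index in tileMemcpySwizzleStore, where the swizzle is still applied to the running index on each iteration.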