diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 1a80901409..917f3a6b1e 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -871,6 +871,11 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, prepareIteratorA(thread_row, masks_a, element_offset_a, param); + unsigned int iter_src_idx = thread_row * param.weightKOffset; + unsigned int iter_dst_idx = thread_row * TILE_COLS_VECTORIZED + thread_col; + unsigned int krow_idx = thread_row + blockIdx.x * BN; + const int ITER_SRC_STEPS = ROW_STEP * param.weightKOffset; + // prefetch the first block tile of A,B into shared memory @@ -923,11 +928,13 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, if (block_krs != num_block_tiles_krs) { #ifdef CP_ASYNC_AVAILABLE curC = tileMemcpyAsyncLoadA(A_block_gmem, SA2, r, s, - masks_a, element_offset_a, thread_row, thread_col, block_k * BK, + masks_a, element_offset_a, thread_row, thread_col, + iter_dst_idx, block_k * BK, start_k, end_k, curC, param); element_offset_b = (r*param.s+s)*param.c + curC; tileMemcpyAsyncLoadB(B_block_gmem, SB2, r, s, curC, element_offset_b, block_k * BK, - start_k, end_k, thread_row, thread_col, param); + start_k, end_k, thread_row, thread_col, + iter_src_idx, iter_dst_idx, krow_idx, ITER_SRC_STEPS,param); asm volatile("cp.async.commit_group;\n" ::); #else curC = tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, r, s, diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index 409c050c89..6df8478a47 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -343,6 +343,7 @@ __device__ __forceinline__ unsigned int tileMemcpyAsyncLoadA( const int64_t element_offset[], unsigned int thread_row, const unsigned int thread_col, + unsigned int iter_idx, const unsigned int block_k, const unsigned int start_k, const unsigned int end_k, @@ -369,7 +370,6 @@ __device__ __forceinline__ unsigned int tileMemcpyAsyncLoadA( if (curC > oldC) clear_mask(masks, curC >= end_k); - unsigned int iter_idx = thread_row * TILE_COLS_VECTORIZED + thread_col; #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++) { bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS)); @@ -393,6 +393,7 @@ __device__ __forceinline__ unsigned int tileMemcpyAsyncLoadA( GGML_UNUSED(element_offset); GGML_UNUSED(thread_row); GGML_UNUSED(thread_col); + GGML_UNUSED(iter_idx); GGML_UNUSED(oldC); GGML_UNUSED(param); NO_DEVICE_CODE; @@ -480,6 +481,10 @@ __device__ __forceinline__ void tileMemcpyAsyncLoadB( const unsigned int end_k, unsigned int thread_row, const unsigned int thread_col, + unsigned int iter_src_idx, + unsigned int iter_dst_idx, + unsigned int krow_idx, + const int ITER_SRC_STEPS, param_t param ) { @@ -500,10 +505,7 @@ __device__ __forceinline__ void tileMemcpyAsyncLoadB( // compile time check that we provided the right amount of registers for storage static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); - unsigned int iter_src_idx = thread_row * param.weightKOffset + ki; - unsigned int iter_dst_idx = thread_row * TILE_COLS_VECTORIZED + thread_col; - unsigned int krow_idx = thread_row + blockIdx.x * TILE_ROWS; - const int ITER_SRC_STEPS = ROW_STEP * param.weightKOffset; + iter_src_idx += ki; #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++) { @@ -529,6 +531,10 @@ __device__ __forceinline__ void tileMemcpyAsyncLoadB( GGML_UNUSED(end_k); GGML_UNUSED(thread_row); GGML_UNUSED(thread_col); + GGML_UNUSED(iter_src_idx); + GGML_UNUSED(iter_dst_idx); + GGML_UNUSED(krow_idx); + GGML_UNUSED(ITER_SRC_STEPS); GGML_UNUSED(param); NO_DEVICE_CODE; #endif