diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 360127b8d5..f646cf73b3 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -805,8 +805,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // prefetch the first block tile of A,B into shared memory half* A_block_gmem = input + (block_m * BM * A_stride); half* B_block_gmem = weight + (block_n * BN); - tileMemcpySwizzleA(A_block_gmem, A_block_smem, K); - tileMemcpySwizzle(B_block_gmem, B_block_smem, N); + tileMemcpySwizzleA(A_block_gmem, A_block_smem, K); + tileMemcpySwizzle(B_block_gmem, B_block_smem, N); // construct const pointers to warp tiles for use inside the inner loop @@ -819,16 +819,16 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, if (block_k != num_block_tiles_k) { - half* A_block_gmem = A + (block_m * BM_dim * A_stride) + (block_k * BK_dim); - half* B_block_gmem = B + (block_k * BK_dim * B_stride) + (block_n * BN_dim); - tileMemcpyLoad(A_block_gmem, A_gmem_cache_reg, K); - tileMemcpyLoad(B_block_gmem, B_gmem_cache_reg, N); + half* A_block_gmem = A + (block_m * BM * A_stride) + (block_k * BK); + half* B_block_gmem = B + (block_k * BK * B_stride) + (block_n * BN); + tileMemcpyLoad(A_block_gmem, A_gmem_cache_reg, K); + tileMemcpyLoad(B_block_gmem, B_gmem_cache_reg, N); } - half* A_warp_tile = A_block_smem + (warp_m * WM_dim * BK_dim); - half* B_warp_tile = B_block_smem + (warp_n * WN_dim); + half* A_warp_tile = A_block_smem + (warp_m * WM * BK); + half* B_warp_tile = B_block_smem + (warp_n * WN); - ldmatrix_a(A_warp_tile, A_register_); - ldmatrix_b(B_warp_tile, B_register_); + ldmatrix_a(A_warp_tile, A_register_); + ldmatrix_b(B_warp_tile, B_register_); // outer product between mma tiles #pragma unroll @@ -863,8 +863,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, B_block_smem = B_block_smem + BUFFER_SIZE * offset_direction; offset_direction = -1 * offset_direction; - tileMemcpySwizzleStoreA(A_gmem_cache_reg, A_block_smem); - tileMemcpySwizzleStore(B_gmem_cache_reg, B_block_smem); + tileMemcpySwizzleStoreA(A_gmem_cache_reg, A_block_smem); + tileMemcpySwizzleStore(B_gmem_cache_reg, B_block_smem); } }