diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 216f895922..fe55de4b91 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -819,9 +819,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, } } - - - if (block_k != num_block_tiles_k) { // switch smem buffers each iteration diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index b242277eb0..0226e715ce 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -66,7 +66,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB( #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ // apply swizzle to the dst index - const unsigned int src_index = thread_row * src_stride + thread_col * 8; + const unsigned int src_index = thread_row * src_stride + start_k + thread_col * 8; unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); @@ -262,7 +262,7 @@ __device__ __forceinline__ void tileMemcpyLoadB( #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ - const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8; + const unsigned int src_index = thread_row * src_stride + start_k + block_k + thread_col * 8; if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && start_k+block_k+thread_col*8 < end_k){ dst_reg[i] = reinterpret_cast(&src[src_index])[0]; }else{ // read 4 halves diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index b5e7b18a2a..87abd015dc 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -44,7 +44,7 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, int kw = 3, // create data int KW = kw, KH = kh, IC = ic, OC = oc; int IW = iw, IH = ih, N = 1; - srand(time(NULL)); + // srand(time(NULL)); // printf(" input: IC = %d, OC = %d, IW = %d, IH = %d \n ", IC, OC, IW, IH); @@ -384,8 +384,8 @@ int main(void) // for(int i = 0; i < ggml_nelements(wino_res); i++) { // for(int i = 0; i < 26*38; i++) { - // // for(int i = 26*38; i < 2*26*38; i++) { - // // for(int i = 0; i < conv2d_data.size(); i++) { + // for(int i = 26*38; i < 2*26*38; i++) { + // for(int i = 0; i < conv2d_data.size(); i++) { // float diff = fabs(im2col_data[i] - conv2d_data[i]); // // if(diff > 0.5) { // printf("(%7.3f, %7.3f, %.2f, %d) \n",