diff --git a/ggml/src/ggml-cuda/conv3d-implicit.cuh b/ggml/src/ggml-cuda/conv3d-implicit.cuh index 9cd7fe4e9b..37449f677e 100644 --- a/ggml/src/ggml-cuda/conv3d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv3d-implicit.cuh @@ -104,7 +104,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB( dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); // TODO: move some checks outside of loop? - if (thread_row < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c){ + if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c){ dst_float4[dst_index] = reinterpret_cast(&src[src_index])[0]; }else{ // read 4 halves dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); @@ -302,7 +302,7 @@ __device__ __forceinline__ void tileMemcpyLoadB( for (unsigned int i = 0; i < NUM_ITERS; i++){ const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8; // TODO : move some checks outside of the loop - if (thread_row < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c){ + if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c){ dst_reg[i] = reinterpret_cast(&src[src_index])[0]; }else{ // read 4 halves dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);