diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index b0d8c17a50..f2c3d60998 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -176,11 +176,25 @@ __device__ __forceinline__ void tileMemcpySwizzleB( unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + unsigned int smem_ptr; + void *ptr = (void *)(dst); + int src_in_bytes = thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k ? 16 : 0; + asm("{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 " + "%0, smem_ptr; }\n" + : "=r"(smem_ptr) + : "l"(ptr)); + + asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_ptr), + "l"(&src[src_index]), + "n"(16), "r"(src_in_bytes)); +#else if (thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k){ dst_float4[dst_index] = reinterpret_cast(&src[src_index])[0]; }else{ // read 4 halves dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); } +#endif thread_row += ROW_STEP; } #else @@ -257,6 +271,19 @@ __device__ __forceinline__ unsigned int tileMemcpySwizzleA( // printf(" %u, %u, %u, %u, %lld, %d\n", i, curR, curS, curC, element_offset[i], valid?1:0); // } // if (valid && curC < end_k){ +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + unsigned int smem_ptr; + void *ptr = (void *)(dst); + int src_in_bytes = valid ? 16 : 0; + asm("{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 " + "%0, smem_ptr; }\n" + : "=r"(smem_ptr) + : "l"(ptr)); + + asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_ptr), + "l"(&src[element_offset[i]+curC]), + "n"(16), "r"(src_in_bytes)); +#else if (valid){ // if(element_offset[i] >= 327680 || element_offset[i] < 0) // printf("%d, %d, %d, %d, %d, %d, %d, %d, %d \n", threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y, @@ -265,6 +292,7 @@ __device__ __forceinline__ unsigned int tileMemcpySwizzleA( } else{ dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); } +#endif thread_row += ROW_STEP; } // #pragma unroll