WIP: adding cp.async calls
This commit is contained in:
parent
11bd9806bf
commit
378bb8368e
|
|
@ -176,11 +176,25 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
|
|||
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
|
||||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
|
||||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
unsigned int smem_ptr;
|
||||
void *ptr = (void *)(dst);
|
||||
int src_in_bytes = thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k ? 16 : 0;
|
||||
asm("{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 "
|
||||
"%0, smem_ptr; }\n"
|
||||
: "=r"(smem_ptr)
|
||||
: "l"(ptr));
|
||||
|
||||
asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_ptr),
|
||||
"l"(&src[src_index]),
|
||||
"n"(16), "r"(src_in_bytes));
|
||||
#else
|
||||
if (thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k){
|
||||
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[src_index])[0];
|
||||
}else{ // read 4 halves
|
||||
dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
|
||||
}
|
||||
#endif
|
||||
thread_row += ROW_STEP;
|
||||
}
|
||||
#else
|
||||
|
|
@ -257,6 +271,19 @@ __device__ __forceinline__ unsigned int tileMemcpySwizzleA(
|
|||
// printf(" %u, %u, %u, %u, %lld, %d\n", i, curR, curS, curC, element_offset[i], valid?1:0);
|
||||
// }
|
||||
// if (valid && curC < end_k){
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
unsigned int smem_ptr;
|
||||
void *ptr = (void *)(dst);
|
||||
int src_in_bytes = valid ? 16 : 0;
|
||||
asm("{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 "
|
||||
"%0, smem_ptr; }\n"
|
||||
: "=r"(smem_ptr)
|
||||
: "l"(ptr));
|
||||
|
||||
asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_ptr),
|
||||
"l"(&src[element_offset[i]+curC]),
|
||||
"n"(16), "r"(src_in_bytes));
|
||||
#else
|
||||
if (valid){
|
||||
// if(element_offset[i] >= 327680 || element_offset[i] < 0)
|
||||
// printf("%d, %d, %d, %d, %d, %d, %d, %d, %d \n", threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y,
|
||||
|
|
@ -265,6 +292,7 @@ __device__ __forceinline__ unsigned int tileMemcpySwizzleA(
|
|||
} else{
|
||||
dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
|
||||
}
|
||||
#endif
|
||||
thread_row += ROW_STEP;
|
||||
}
|
||||
// #pragma unroll
|
||||
|
|
|
|||
Loading…
Reference in New Issue