WIP: adding cp.async calls

This commit is contained in:
bssrdf 2025-11-14 18:48:06 -05:00
parent 11bd9806bf
commit 378bb8368e
1 changed files with 28 additions and 0 deletions

View File

@ -176,11 +176,25 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
unsigned int smem_ptr;
void *ptr = (void *)(dst);
int src_in_bytes = thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k ? 16 : 0;
asm("{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 "
"%0, smem_ptr; }\n"
: "=r"(smem_ptr)
: "l"(ptr));
asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_ptr),
"l"(&src[src_index]),
"n"(16), "r"(src_in_bytes));
#else
if (thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k){
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[src_index])[0];
}else{ // read 4 halves
dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
}
#endif
thread_row += ROW_STEP;
}
#else
@ -257,6 +271,19 @@ __device__ __forceinline__ unsigned int tileMemcpySwizzleA(
// printf(" %u, %u, %u, %u, %lld, %d\n", i, curR, curS, curC, element_offset[i], valid?1:0);
// }
// if (valid && curC < end_k){
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
unsigned int smem_ptr;
void *ptr = (void *)(dst);
int src_in_bytes = valid ? 16 : 0;
asm("{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 "
"%0, smem_ptr; }\n"
: "=r"(smem_ptr)
: "l"(ptr));
asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_ptr),
"l"(&src[element_offset[i]+curC]),
"n"(16), "r"(src_in_bytes));
#else
if (valid){
// if(element_offset[i] >= 327680 || element_offset[i] < 0)
// printf("%d, %d, %d, %d, %d, %d, %d, %d, %d \n", threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y,
@ -265,6 +292,7 @@ __device__ __forceinline__ unsigned int tileMemcpySwizzleA(
} else{
dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
}
#endif
thread_row += ROW_STEP;
}
// #pragma unroll