move some registers to constant memory space
parent b015e4b7dc
commit 0cb1ff419a
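The diff below stops recomputing derived sizes (K, inChannelOffset, weightKOffset, PQ, KPQ, NKPQ) in every thread of conv2d_implicit_kernel; they are instead computed once on the host and carried in param_t, which is passed to the kernel by value and therefore resides in the constant (.param) address space rather than in per-thread registers. A minimal sketch of the pattern, using hypothetical names that are not part of the ggml sources:

// Sketch only: hypothetical names, not the actual ggml code.
// A derived stride is computed once on the host, stored in a struct, and the
// struct is passed to the kernel by value, so reads come from the kernel
// parameter (constant) space instead of being recomputed per thread.
#include <cuda_runtime.h>

struct conv_params_sketch {
    unsigned int c, w;
    unsigned int inChannelOffset;   // precomputed on the host: c * w
};

__global__ void demo_kernel(const conv_params_sketch p, unsigned int * out) {
    const unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
    // Before: const unsigned int inChannelOffset = p.c * p.w;  (per-thread IMUL + register)
    // After: the precomputed field is read directly from the kernel parameter.
    out[idx] = p.inChannelOffset + idx;
}

int main() {
    conv_params_sketch p{3, 64, 0};
    p.inChannelOffset = p.c * p.w;          // host-side precomputation
    unsigned int * d_out = nullptr;
    cudaMalloc(&d_out, 32 * sizeof(unsigned int));
    demo_kernel<<<1, 32>>>(p, d_out);
    cudaDeviceSynchronize();
    cudaFree(d_out);
    return 0;
}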
@@ -781,13 +781,13 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
     constexpr unsigned int MMA_M = 16;
     constexpr unsigned int MMA_N = 8;

-    const unsigned int K = param.c;
-    const uint inChannelOffset = param.c * param.w;
-    const uint weightKOffset = param.c * param.r * param.s;
+    // const unsigned int K = param.c;
+    // const uint inChannelOffset = param.c * param.w;
+    // const uint weightKOffset = param.c * param.r * param.s;

-    const unsigned int PQ = param.Ow * param.Oh;
-    const unsigned int KPQ = param.k * PQ;
-    const unsigned int NKPQ = param.n * KPQ;
+    // const unsigned int PQ = param.Ow * param.Oh;
+    // const unsigned int KPQ = param.k * PQ;
+    // const unsigned int NKPQ = param.n * KPQ;

     // loop bounds, constexpr where possible allows for loop unrolling
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
@@ -799,9 +799,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
     constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N;
     const unsigned int z = blockIdx.z;

-    const unsigned int ks = (ksplit > 0) ? (K + ksplit - 1) / ksplit : K;
+    const unsigned int ks = (ksplit > 0) ? (param.c + ksplit - 1) / ksplit : param.c;
     const unsigned int start_k = (ksplit > 0) ? z * ks : 0;
-    const unsigned int end_k = min(start_k + ks, K);
+    const unsigned int end_k = min(start_k + ks, param.c);
     const unsigned int num_block_tiles_k = (ks + (BK-1)) / BK;
     const unsigned int num_block_tiles_krs = num_block_tiles_k * param.r * param.s;

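For reference, ks is the ceiling division of the channel count by ksplit, and block z handles channels [start_k, end_k). An illustrative calculation (numbers not from the source): with param.c = 100 and ksplit = 3, ks = 34, so z = 2 covers start_k = 68 up to end_k = min(68 + 34, 100) = 100.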
@@ -888,11 +888,11 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
     // prefetch the first block tile of A,B into shared memory

     const half* A_block_gmem = input;
-    const half* B_block_gmem = kernel + block_n * BN * weightKOffset;
+    const half* B_block_gmem = kernel + block_n * BN * param.weightKOffset;

     unsigned int curC = tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, 0, 0, masks_a, element_offset_a,
-                                                            thread_idx, start_k, end_k, inChannelOffset, param);
-    tileMemcpySwizzleB<BN, NUM_THREADS>(B_block_gmem, B_block_smem, 0, 0, start_k, end_k, weightKOffset, param);
+                                                            thread_idx, start_k, end_k, param);
+    tileMemcpySwizzleB<BN, NUM_THREADS>(B_block_gmem, B_block_smem, 0, 0, start_k, end_k, param);

     int offset_direction = 1;
     unsigned int block_k = 0;
@@ -945,8 +945,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
         if (block_krs != num_block_tiles_krs){
             curC = tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, r, s,
                                                            masks_a, element_offset_a, thread_idx, block_k * BK,
-                                                           start_k, end_k, curC, inChannelOffset, param);
-            tileMemcpyLoadB<BN, BK, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, r, s, block_k * BK, start_k, end_k, weightKOffset, param);
+                                                           start_k, end_k, curC, param);
+            tileMemcpyLoadB<BN, BK, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, r, s, block_k * BK, start_k, end_k, param);
         }

         half* A_warp_tile = A_block_smem + A_warp_tile_offset;
@@ -1064,12 +1064,12 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
                 const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
                 uint32_t dst_ptr = *(reinterpret_cast<uint32_t*>(&smemoutput[idx+j*16*BN])); // 32*BN/2 = 16*BN
                 half (&res_)[2] = reinterpret_cast<half(&)[2]>(dst_ptr);
-                if (n < param.n && row < param.k && col < PQ) {
-                    const uint outOffset = ((ksplit > 0) ? z * NKPQ : 0) + n * KPQ + row * PQ + col;
+                if (n < param.n && row < param.k && col < param.PQ) {
+                    const uint outOffset = ((ksplit > 0) ? z * param.NKPQ : 0) + n * param.KPQ + row * param.PQ + col;
                     output[outOffset] = ggml_cuda_cast<T>(res_[0]);
                 }
-                if (n < param.n && row+1 < param.k && col < PQ) {
-                    const uint outOffset = ((ksplit > 0) ? z * NKPQ : 0) + n * KPQ + (row+1) * PQ + col;
+                if (n < param.n && row+1 < param.k && col < param.PQ) {
+                    const uint outOffset = ((ksplit > 0) ? z * param.NKPQ : 0) + n * param.KPQ + (row+1) * param.PQ + col;
                     output[outOffset] = ggml_cuda_cast<T>(res_[1]);
                 }
             }
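The new param.PQ / param.KPQ / param.NKPQ fields are the flattened NCHW strides of the output tensor, and the offset computation above can be restated as in the following host-side helper (an illustrative sketch with a hypothetical name, not code from this commit):

// Illustrative restatement of the output offset used above (hypothetical helper).
// For an NCHW output of shape [N, K, Oh, Ow]:
//   PQ   = Oh * Ow   -> one spatial plane of a single output channel
//   KPQ  = K  * PQ   -> one batch element
//   NKPQ = N  * KPQ  -> one whole output tensor (used to place the z-th
//                       partial result when the channel dimension is split)
static unsigned int out_offset(unsigned int z, unsigned int n, unsigned int row,
                               unsigned int col, unsigned int PQ, unsigned int KPQ,
                               unsigned int NKPQ, bool ksplit) {
    return (ksplit ? z * NKPQ : 0) + n * KPQ + row * PQ + col;
}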
@@ -1389,6 +1389,14 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor *
     const uint B = input->ne[3]; // n_batches


     int64_t pp[3];
+    // const unsigned int K = param.c;
+    // const uint inChannelOffset = param.c * param.w;
+    // const uint weightKOffset = param.c * param.r * param.s;
+    // const unsigned int PQ = param.Ow * param.Oh;
+    // const unsigned int KPQ = param.k * PQ;
+    // const unsigned int NKPQ = param.n * KPQ;
+
+
     param_t params = { B, IC, IH, IW, OC, KH, KW, ST_Y, ST_X, PD_Y, PD_X, DL_Y, DL_X, OH, OW,
                        init_fastdiv_values(KW*IC),
@@ -1396,7 +1404,13 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor *
                        init_fastdiv_values(IC),
                        init_fastdiv_values(KW*KH),
                        init_fastdiv_values(KW),
-                       init_fastdiv_values(OW*OH)};
+                       init_fastdiv_values(OW*OH),
+                       pp[0], pp[1], pp[2],
+                       IC*IW,
+                       IC*KW*KH,
+                       OW*OH,
+                       OC*OW*OH,
+                       B*OC*OW*OH};

     if (kernel->type == GGML_TYPE_F16) {
         conv2d_implicit_cuda_f16(ctx, X_D, (half *) K_D, Y_D, cc, params, st);
@@ -24,6 +24,13 @@ typedef struct{
     uint3 S_fastdiv;
     uint3 OHOW_fastdiv;
     int64_t inc_next[3];
+    // unsigned int K;
+    unsigned int inChannelOffset;
+    unsigned int weightKOffset;
+    unsigned int PQ;
+    unsigned int KPQ;
+    unsigned int NKPQ;
+
 } param_t;

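Because param_t is passed to the kernel by value, these new fields live in the kernel parameter space, which NVIDIA GPUs back with constant memory; a read such as param.weightKOffset is a broadcast load and does not cost a register per thread until the value is actually consumed. One way to double-check this on a given architecture (not part of the commit) is to compile the translation unit to PTX with nvcc -ptx and confirm that the struct fields are accessed through ld.param instructions rather than recomputed with integer multiplies.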
@@ -125,7 +132,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
     const unsigned int curS,
     const unsigned int start_k,
     const unsigned int end_k,
-    const unsigned int src_stride,
+    // const unsigned int src_stride,
     param_t param
 ){
 #if __CUDA_ARCH__ >= GGML_CUDA_TURING
@@ -161,7 +168,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
 #pragma unroll
     for (unsigned int i = 0; i < NUM_ITERS; i++){
         // apply swizzle to the dst index
-        const unsigned int src_index = thread_row * src_stride + ki;
+        const unsigned int src_index = thread_row * param.weightKOffset + ki;
         unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
@@ -195,7 +202,6 @@ __device__ __forceinline__ unsigned int tileMemcpySwizzleA(
     const unsigned int thread_idx,
     const unsigned int start_k,
     const unsigned int end_k,
-    const unsigned int inChannelOffset,
     param_t param
 )
 {
@@ -300,7 +306,7 @@ __device__ __forceinline__ unsigned int tileMemcpyLoadA(
     const unsigned int start_k,
     const unsigned int end_k,
     unsigned int oldC,
-    const unsigned int inChannelOffset,
+    // const unsigned int inChannelOffset,
     param_t param
 ){
 #if __CUDA_ARCH__ >= GGML_CUDA_TURING
@@ -396,7 +402,7 @@ __device__ __forceinline__ void tileMemcpyLoadB(
     const unsigned int block_k,
     const unsigned int start_k,
     const unsigned int end_k,
-    const unsigned int src_stride,
+    // const unsigned int src_stride,
     param_t param
 ){
 #if __CUDA_ARCH__ >= GGML_CUDA_TURING
@@ -426,7 +432,7 @@ __device__ __forceinline__ void tileMemcpyLoadB(

 #pragma unroll
     for (unsigned int i = 0; i < NUM_ITERS; i++){
-        const unsigned int src_index = thread_row * src_stride + ki;
+        const unsigned int src_index = thread_row * param.weightKOffset + ki;
         if (thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k){
             dst_reg[i] = reinterpret_cast<const float4 *>(&src[src_index])[0];
         }else{ // read 4 halves