fuse cast to float into conv epilogue; improve swizzling for output
This commit is contained in:
parent
d2d814c156
commit
a428feecdd
|
|
@ -681,11 +681,11 @@ __device__ __forceinline__ void ldmatrix_b(
|
|||
#endif
|
||||
}
|
||||
|
||||
template<const int BM, const int BN, const int BK, const int WM, const int WN,
|
||||
template<typename T, const int BM, const int BN, const int BK, const int WM, const int WN,
|
||||
const int WK, const int NUM_THREADS>
|
||||
static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
|
||||
const half * __restrict__ kernel,
|
||||
half * __restrict__ output,
|
||||
T * __restrict__ output,
|
||||
const param_t param) {
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
|
||||
|
||||
|
|
@ -828,27 +828,36 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
|
|||
uint32_t (®_)[2] = reinterpret_cast<uint32_t(&)[2]>(acc_register_[mma_m][mma_n]);
|
||||
uint idx = output_sts_addr +
|
||||
mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N;
|
||||
idx = idx ^ ((idx & 0b110000000000) >> 9);
|
||||
idx = idx ^ ((idx & 0b1110000000) >> 4);
|
||||
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx]);
|
||||
dst_ptr[0] = reg_[0];
|
||||
dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx + 8 * BN / 2]);
|
||||
idx = (idx + 8 * BN / 2 ) ^ 0b010;
|
||||
dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx]);
|
||||
dst_ptr[0] = reg_[1];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int subk = 0; subk < WN / 2; ++subk){
|
||||
for (int subk = 0; subk < WN / 4; ++subk){
|
||||
const uint row = m_idx + subk*2 + i * WN / 2;
|
||||
uint idx = output_lds_addr + subk*2; // + j*32*BN/2;
|
||||
idx = idx ^ ((idx & 0b110000000000) >> 9);
|
||||
idx = idx ^ ((idx & 0b1110000000) >> 4);
|
||||
for (int j = 0; j < 4; ++j){
|
||||
const uint row = m_idx + subk + i * WN / 2;
|
||||
const uint gemm_i = n_idx + j*32;
|
||||
const int n = fastdiv(gemm_i, param.PQZ_fastdiv);
|
||||
const int col = fastmodulo(gemm_i, param.PQZ_fastdiv);
|
||||
uint32_t dst_ptr = *(reinterpret_cast<uint32_t*>(&smemoutput[idx+j*32*BN/2]));
|
||||
half (&res_)[2] = reinterpret_cast<half(&)[2]>(dst_ptr);
|
||||
if(n < param.n && row < param.k && col < PQZ){
|
||||
const uint outOffset = (n * param.k + row) * PQZ + col;
|
||||
uint idx = output_lds_addr + subk + j*32*BN/2;
|
||||
idx = idx ^ ((idx & 0b1110000000) >> 4);
|
||||
output[outOffset] = smemoutput[idx];
|
||||
output[outOffset] = ggml_cuda_cast<T>(res_[0]);
|
||||
}
|
||||
if(n < param.n && row+1 < param.k && col < PQZ){
|
||||
const uint outOffset = (n * param.k + row + 1) * PQZ + col;
|
||||
output[outOffset] = ggml_cuda_cast<T>(res_[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -924,15 +933,17 @@ static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
|
|||
|
||||
ne = P.c * P.r * P.s * P.t * P.k;
|
||||
ne01 = P.r * P.s * P.t;
|
||||
ggml_cuda_pool_alloc<half> kernel_f16(ctx.pool(id), ne);
|
||||
dim3 dimGrid1((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
|
||||
(ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
|
||||
(ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ;
|
||||
NCHW2NHWC<half, half><<<dimGrid1, dimBlock, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
|
||||
ggml_cuda_pool_alloc<half> kernel_f16(ctx.pool(id));
|
||||
if(ne01 > 1){
|
||||
kernel_f16.alloc(ne);
|
||||
dim3 dimGrid1((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
|
||||
(ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
|
||||
(ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ;
|
||||
NCHW2NHWC<half, half><<<dimGrid1, dimBlock, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
|
||||
}
|
||||
|
||||
const half *X_H = input_f16.get();
|
||||
const half *K_H = kernel_f16.get();
|
||||
ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), P.k * P.Od *P.Oh * P.Ow * P.n);
|
||||
const half *K_H = ne01 == 1 ? K_D : kernel_f16.get();
|
||||
|
||||
constexpr unsigned int BM_dim = 256;
|
||||
constexpr unsigned int BN_dim = 256;
|
||||
|
|
@ -952,16 +963,14 @@ static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
|
|||
constexpr unsigned int NumThreads = ThreadsM * ThreadsN;
|
||||
const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half);
|
||||
|
||||
cudaFuncSetAttribute(conv3d_implicit_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, NumThreads>,
|
||||
cudaFuncSetAttribute(conv3d_implicit_kernel<float, BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, NumThreads>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75
|
||||
dim3 gridDim(BlocksN, BlocksM);
|
||||
dim3 blockDim(ThreadsN, ThreadsM);
|
||||
|
||||
conv3d_implicit_kernel<BM_dim, BN_dim, BK_dim,
|
||||
conv3d_implicit_kernel<float, BM_dim, BN_dim, BK_dim,
|
||||
WM_dim, WN_dim, WK_dim, NumThreads>
|
||||
<<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_H.get(), P);
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||
to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow *P.Od * P.n, st);
|
||||
<<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_D, P);
|
||||
} else{
|
||||
conv3d_implicit_cuda<half, 1>(X_D, K_D, Y_D, P, st);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -350,9 +350,9 @@ int main(void)
|
|||
{
|
||||
ggml_time_init();
|
||||
std::vector<std::tuple<int, int, int, int, int, int, int, int>> configs = {
|
||||
std::make_tuple(1,2,16,32,4,3,3,3),
|
||||
// std::make_tuple(1,2,16,32,4,3,3,3),
|
||||
// std::make_tuple(320,1280,26,38,8,3,3,3),
|
||||
// std::make_tuple(1280,1280,26,38,8,3,3,3),
|
||||
std::make_tuple(1280,1280,26,38,8,3,3,3),
|
||||
// std::make_tuple(320,1280,52,76,8,3,3,3),
|
||||
// std::make_tuple(1280,1280,52,76,8,3,3,3),
|
||||
// std::make_tuple(320,1280,104,152,8,3,3,3),
|
||||
|
|
@ -380,7 +380,7 @@ int main(void)
|
|||
// fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f);
|
||||
|
||||
|
||||
int iterations = 20;
|
||||
int iterations = 0;
|
||||
|
||||
double run_time0;
|
||||
std::vector<float> im2col_data = compute_graph(model, allocr, build_graph_0, iterations,
|
||||
|
|
|
|||
Loading…
Reference in New Issue