added split-k mode to tensor core path
This commit is contained in:
parent a428feecdd
commit 15daa5a6a8
@@ -13,18 +13,19 @@ constexpr uint WARPSIZE = 32;
 
 
-//currently not use; in future for split-k kernels
-// static __global__ void reduce_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const int nrows) {
-//     const int row = blockIdx.x;
-//     const int col = threadIdx.x;
+template <typename src_T, typename dst_T>
+static __global__ void reduce_f32(const src_T * __restrict__ x, dst_T * __restrict__ dst, const int ncols, const int nrows) {
+    const int row = blockIdx.x;
+    const int col = threadIdx.x;
 
-//     float sum = 0.0f;
-//     if (row * blockDim.x + col < ncols) {
-//         for (int i = 0; i < nrows; ++i){
-//             sum += x[i * ncols + row * blockDim.x + col];
-//         }
-//         dst[row * blockDim.x + col] = sum;
-//     }
-// }
+    float sum = 0.0f;
+    if (row * blockDim.x + col < ncols) {
+        for (int i = 0; i < nrows; ++i){
+            sum += ggml_cuda_cast<float>(x[i * ncols + row * blockDim.x + col]);
+        }
+        dst[row * blockDim.x + col] = ggml_cuda_cast<dst_T>(sum);
+    }
+}
 
 template <typename src_T, typename dst_T>
 static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){
@@ -682,7 +683,7 @@ __device__ __forceinline__ void ldmatrix_b(
 }
 
 template<typename T, const int BM, const int BN, const int BK, const int WM, const int WN,
-         const int WK, const int NUM_THREADS>
+         const int WK, const int ksplit, const int NUM_THREADS>
 static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
                                               const half * __restrict__ kernel,
                                               T * __restrict__ output,
@@ -699,12 +700,19 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
     const uint inChannelOffset = param.c * param.w;
     const uint inDepthOffset = param.h * param.c * param.w;
     const uint inNOffset = param.c * param.w * param.h * param.d;
+    const unsigned int z = blockIdx.z;
 
     // loop bounds, constexpr where possible allows for loop unrolling
     constexpr unsigned int mma_tiles_per_warp_k = 4;
     constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M;
     constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N;
-    const unsigned int num_block_tiles_k = (K + (BK-1)) / BK;
+
+    const unsigned int ks = (ksplit > 0) ? (weightKOffset + ksplit - 1) / ksplit : weightKOffset;
+    const unsigned int start_k = (ksplit > 0) ? z * ks : 0;
+    const unsigned int end_k = min(start_k + ks, weightKOffset);
+    const unsigned int num_block_tiles_k = (ks + (BK-1)) / BK;
+
+    // const unsigned int num_block_tiles_k = (K + (BK-1)) / BK;
 
     // calculate block/warp indices
     const unsigned int block_m = blockIdx.y;
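For illustration only (not part of the commit): the ks/start_k/end_k arithmetic above partitions the reduction length (weightKOffset) across the blockIdx.z slices. A minimal host-side sketch with made-up sizes:

#include <algorithm>
#include <cstdio>

int main() {
    const unsigned int weightKOffset = 100; // total reduction length, e.g. c*t*r*s (hypothetical value)
    const unsigned int ksplit = 3;          // number of z-slices
    const unsigned int ks = (weightKOffset + ksplit - 1) / ksplit; // per-slice span, rounded up
    for (unsigned int z = 0; z < ksplit; ++z) {
        const unsigned int start_k = z * ks;
        const unsigned int end_k = std::min(start_k + ks, weightKOffset);
        printf("z=%u covers k in [%u, %u)\n", z, start_k, end_k); // [0,34) [34,68) [68,100)
    }
    return 0;
}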
@@ -750,8 +758,8 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
 
     const half* A_block_gmem = input;
     const half* B_block_gmem = kernel + block_n * BN * weightKOffset;
-    tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, inNOffset, inDepthOffset, inChannelOffset, param);
-    tileMemcpySwizzleB<BN, NUM_THREADS>(B_block_gmem, B_block_smem, weightKOffset, param);
+    tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, start_k, end_k, inNOffset, inDepthOffset, inChannelOffset, param);
+    tileMemcpySwizzleB<BN, NUM_THREADS>(B_block_gmem, B_block_smem, start_k, end_k, weightKOffset, param);
 
     int offset_direction = 1;
 
@@ -761,9 +769,10 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
         if (block_k != num_block_tiles_k){
             const half* A_block_gmem = input;
             const half* B_block_gmem = kernel + (block_n * BN * weightKOffset);
-            tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, block_k * BK,
+            tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, block_k * BK, start_k, end_k,
                 inNOffset, inDepthOffset, inChannelOffset, param);
-            tileMemcpyLoadB<BN, BK, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param);
+            tileMemcpyLoadB<BN, BK, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, block_k * BK, start_k, end_k,
+                weightKOffset, param);
         }
         half* A_warp_tile = A_block_smem + (warp_m * WM * BK);
         half* B_warp_tile = B_block_smem + (warp_n * WN * BK);
@@ -852,12 +861,28 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
             uint32_t dst_ptr = *(reinterpret_cast<uint32_t*>(&smemoutput[idx+j*32*BN/2]));
             half (&res_)[2] = reinterpret_cast<half(&)[2]>(dst_ptr);
             if(n < param.n && row < param.k && col < PQZ){
-                const uint outOffset = (n * param.k + row) * PQZ + col;
+                // if constexpr (ksplit > 0) {
+                //     const uint outOffset = (n * param.k + row) * PQZ + col;
+                //     output[outOffset] = ggml_cuda_cast<T>(res_[0]);
+                // } else {
+                //     const uint outOffset = (n * param.k + row) * PQZ + col;
+                //     output[outOffset] = ggml_cuda_cast<T>(res_[0]);
+                // }
+                const uint outOffset = ksplit > 0 ? (z * param.n * param.k + n * param.k + row) * PQZ + col :
+                                                    (n * param.k + row) * PQZ + col;
                 output[outOffset] = ggml_cuda_cast<T>(res_[0]);
             }
             if(n < param.n && row+1 < param.k && col < PQZ){
-                const uint outOffset = (n * param.k + row + 1) * PQZ + col;
+                const uint outOffset = ksplit > 0 ? (z * param.n * param.k + n * param.k + row+1) * PQZ + col :
+                                                    (n * param.k + row+1) * PQZ + col;
                 output[outOffset] = ggml_cuda_cast<T>(res_[1]);
+                // if constexpr (ksplit > 0) {
+                //     const uint outOffset = (n * param.k + row) * PQZ + col;
+                //     output[outOffset] = ggml_cuda_cast<T>(res_[0]);
+                // } else {
+                //     const uint outOffset = (n * param.k + row + 1) * PQZ + col;
+                //     output[outOffset] = ggml_cuda_cast<T>(res_[1]);
+                // }
             }
         }
     }
@@ -914,6 +939,33 @@ static void conv3d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D,
                 WNITER, TM, TN, NUM_THREADS, 1, false, 0><<<grid, thblock, 0, st>>>(X_D, K_D, Y_D, P);
 }
 
+template<const int BM, const int BN, const int BK,
+         const int WM, const int WN, const int WK, const int ksplit,
+         const unsigned int ThreadsM, const unsigned int ThreadsN,
+         const int NUM_THREADS>
+static void launch_conv3d_implicit_split_kernel(ggml_backend_cuda_context & ctx, const half *X_H, const half *K_H, float *Y_D,
+                                                const unsigned int BlocksM, const unsigned int BlocksN,
+                                                const unsigned int shmem_bytes,
+                                                const param_t P, cudaStream_t st){
+
+    int id = ggml_cuda_get_device();
+
+    ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), ksplit * P.k * P.Od * P.Oh * P.Ow * P.n);
+    cudaFuncSetAttribute(conv3d_implicit_kernel<half, BM, BN, BK, WM, WN, WK, ksplit, NUM_THREADS>,
+                         cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75
+    dim3 gridDim(BlocksN, BlocksM, ksplit);
+    dim3 blockDim(ThreadsN, ThreadsM);
+
+    conv3d_implicit_kernel<half, BM, BN, BK, WM, WN, WK, ksplit, NUM_THREADS>
+        <<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_H.get(), P);
+
+    const unsigned int nrows = P.n * P.k * P.Oh * P.Ow * P.Od;
+    const unsigned int blockx = (nrows + 511) / 512;
+    const dim3 block_nums(blockx, 1, 1);
+    const dim3 block_dims(512, 1, 1);
+    reduce_f32<half, float><<<block_nums, block_dims, 0, st>>>(Y_H.get(), Y_D, nrows, ksplit);
+}
+
 static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) {
 
     if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1 || P.t > 1)) {
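For illustration only (not part of the commit): Y_H above holds ksplit full-size partial outputs back to back, and the final output is their element-wise sum, which is what the reduce_f32 launch performs on the GPU. A minimal CPU reference with hypothetical names and sizes:

#include <cstdio>
#include <vector>

// Element-wise sum of ksplit partial buffers laid out as [ksplit][nrows],
// matching dst[col] = sum_i src[i * nrows + col] in reduce_f32.
static void reduce_partials_cpu(const std::vector<float> &partials, std::vector<float> &out,
                                unsigned int nrows, unsigned int ksplit) {
    for (unsigned int col = 0; col < nrows; ++col) {
        float sum = 0.0f;
        for (unsigned int i = 0; i < ksplit; ++i) {
            sum += partials[i * nrows + col];
        }
        out[col] = sum;
    }
}

int main() {
    const unsigned int nrows = 8, ksplit = 3;          // hypothetical sizes
    std::vector<float> partials(ksplit * nrows, 1.0f); // each slice contributes 1.0
    std::vector<float> out(nrows, 0.0f);
    reduce_partials_cpu(partials, out, nrows, ksplit);
    printf("out[0] = %.1f (expected 3.0)\n", out[0]);
    return 0;
}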
@@ -956,6 +1008,9 @@ static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
         constexpr unsigned int WM_dim = BM_dim / WARPS_PER_BLOCK_M;
         constexpr unsigned int WN_dim = BN_dim / WARPS_PER_BLOCK_N;
         constexpr unsigned int WK_dim = BK_dim / WARPS_PER_BLOCK_K;
+
+        static_assert(WN_dim % 4 == 0, "final output requires this to be bank conflicts free");
+
         const unsigned int BlocksM = (P.n * P.Oh * P.Ow * P.Od + BM_dim - 1) / BM_dim;
         const unsigned int BlocksN = (P.k + BN_dim - 1) / BN_dim;
         constexpr unsigned int ThreadsM = WARPS_PER_BLOCK_M;
@@ -963,13 +1018,73 @@ static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
         constexpr unsigned int NumThreads = ThreadsM * ThreadsN;
         const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half);
 
-        cudaFuncSetAttribute(conv3d_implicit_kernel<float, BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, NumThreads>,
+        const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+        // if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) {
+        if (BlocksM * BlocksN < 2*(unsigned int)nsm){
+            int j, max_remaining_waves = -1, candidate = -1;
+            int ks = min(12, nsm / (BlocksM * BlocksN));
+            if (ks < 2 && (BlocksM * BlocksN) % nsm < nsm*4/5)
+                ks = 12;
+            for (j = 2; j <= ks; j++){
+                const int remainder = (BlocksM * BlocksN * j) % nsm;
+                if ((P.c * P.r * P.s * P.t) % (8*j) == 0){
+                    if (remainder == 0) {
+                        candidate = j;
+                        max_remaining_waves = 0;
+                        break;
+                    } else if (remainder > max_remaining_waves) {
+                        max_remaining_waves = remainder;
+                        candidate = j;
+                    }
+                }
+            }
+
+            if(candidate != -1){
+                j = candidate;
+                if (j == 2) {
+                    launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 2,
+                        ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 3) {
+                    launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 3,
+                        ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 4) {
+                    launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 4,
+                        ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 5) {
+                    launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 5,
+                        ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 6) {
+                    launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 6,
+                        ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 7) {
+                    launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 7,
+                        ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 8) {
+                    launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 8,
+                        ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 9) {
+                    launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 9,
+                        ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 10) {
+                    launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 10,
+                        ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 11) {
+                    launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 11,
+                        ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 12) {
+                    launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 12,
+                        ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                }
+                return;
+            }
+        }
+        cudaFuncSetAttribute(conv3d_implicit_kernel<float, BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 0, NumThreads>,
             cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75
         dim3 gridDim(BlocksN, BlocksM);
         dim3 blockDim(ThreadsN, ThreadsM);
 
         conv3d_implicit_kernel<float, BM_dim, BN_dim, BK_dim,
-            WM_dim, WN_dim, WK_dim, NumThreads>
+            WM_dim, WN_dim, WK_dim, 0, NumThreads>
             <<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_D, P);
     } else{
         conv3d_implicit_cuda<half, 1>(X_D, K_D, Y_D, P, st);
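For illustration only (not part of the commit): the candidate search above, as a standalone function. It returns the first j whose total grid BlocksM*BlocksN*j fills the SMs in whole waves, otherwise the j leaving the fullest partial last wave, and only considers j that keep c*r*s*t a multiple of 8*j. The example shape is made up.

#include <cstdio>

// Mirrors the selection loop in conv3d_implicit_cuda_f16; returns -1 if no j in [2, ks_max] fits.
static int pick_ksplit(long blocks_mn, int nsm, long crst, int ks_max) {
    int candidate = -1, max_remaining_waves = -1;
    for (int j = 2; j <= ks_max; ++j) {
        if (crst % (8L * j) != 0) continue;          // reduction length must stay a multiple of 8*j
        const int remainder = (int)((blocks_mn * j) % nsm);
        if (remainder == 0) return j;                // whole number of waves: take it immediately
        if (remainder > max_remaining_waves) {       // otherwise prefer the fullest last wave
            max_remaining_waves = remainder;
            candidate = j;
        }
    }
    return candidate;
}

int main() {
    // e.g. 18 output tiles on an 80-SM device with c*r*s*t = 1280*3*3*3:
    printf("ksplit = %d\n", pick_ksplit(18, 80, 1280L * 3 * 3 * 3, 12)); // prints 4
    return 0;
}

Because ksplit is a template parameter of conv3d_implicit_kernel, the runtime candidate has to be mapped back to a compile-time constant, hence the if/else chain over j = 2 through 12 in the diff above.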
@@ -65,6 +65,8 @@ unsigned int NUM_THREADS>
 __device__ __forceinline__ void tileMemcpySwizzleB(
     const half* src,
     half* dst,
+    const unsigned int start_k,
+    const unsigned int end_k,
     const unsigned int src_stride,
     param_t param
 ){
@@ -90,7 +92,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
     constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
     unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
     const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
-    const unsigned int kidx = thread_col*8;
+    const unsigned int kidx = start_k + thread_col*8;
     const int4 curIdx = inputIndices<0>(kidx, param);
     const int curC = curIdx.x;
     const int curT = curIdx.y;
@@ -104,7 +106,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
         // TODO: move some checks outside of loop?
-        if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c){
+        if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c && kidx < end_k){
             dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[src_index])[0];
         }else{ // read 4 halves
             dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
@@ -127,7 +129,8 @@ unsigned int NUM_THREADS>
 __device__ __forceinline__ void tileMemcpySwizzleA(
     const half* src,
     half* dst,
-    // const unsigned int src_stride,
+    const unsigned int start_k,
+    const unsigned int end_k,
     const unsigned int inNOffset,
     const unsigned int inDepthOffset,
     const unsigned int inChannelOffset,
@@ -157,7 +160,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
     unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
     const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
 
-    const unsigned int kidx = thread_col*8;
+    const unsigned int kidx = start_k+thread_col*8;
     const int4 curIdx = inputIndices<0>(kidx, param);
 
     #pragma unroll
@@ -177,7 +180,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
         unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
-        if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && curC < param.c){
+        if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && curC < param.c && kidx < end_k){
             int inOffsetTmp = curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC;
             dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[n * inNOffset + inOffsetTmp])[0];
         } else{
@@ -203,6 +206,8 @@ __device__ __forceinline__ void tileMemcpyLoadA(
     float4 (&dst_reg)[ELEMENTS_PER_THREAD],
     // const unsigned int src_stride,
     const unsigned int block_k,
+    const unsigned int start_k,
+    const unsigned int end_k,
     const unsigned int inNOffset,
     const unsigned int inDepthOffset,
     const unsigned int inChannelOffset,
@@ -227,7 +232,7 @@ __device__ __forceinline__ void tileMemcpyLoadA(
     // compile time check that we provided the right amount of registers for storage
    static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
 
-    const unsigned int kidx = block_k + thread_col*8;
+    const unsigned int kidx = start_k + block_k + thread_col*8;
     const int4 curIdx = inputIndices<0>(kidx, param);
 
     #pragma unroll
@@ -243,7 +248,8 @@ __device__ __forceinline__ void tileMemcpyLoadA(
         const int curH = posh_ori + curIdx.z * param.dilation1; // input h
         const int curW = posw_ori + curIdx.w * param.dilation0; // input w
         const int curC = curIdx.x;
-        if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && curC < param.c){
+        if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d
+            && curC < param.c && kidx < end_k){
             int inOffsetTmp = curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC;
             dst_reg[i] = reinterpret_cast<const float4 *>(&src[n * inNOffset + inOffsetTmp])[0];
         } else{
@@ -270,6 +276,8 @@ __device__ __forceinline__ void tileMemcpyLoadB(
     const half* src,
     float4 (&dst_reg)[ELEMENTS_PER_THREAD],
     const unsigned int block_k,
+    const unsigned int start_k,
+    const unsigned int end_k,
     const unsigned int src_stride,
     param_t param
 ){
@@ -292,7 +300,7 @@ __device__ __forceinline__ void tileMemcpyLoadB(
     // compile time check that we provided the right amount of registers for storage
     static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
 
-    const unsigned int kidx = block_k + thread_col*8;
+    const unsigned int kidx = start_k + block_k + thread_col*8;
     const int4 curIdx = inputIndices<0>(kidx, param);
     const int curC = curIdx.x;
     const int curT = curIdx.y;
@@ -300,9 +308,10 @@ __device__ __forceinline__ void tileMemcpyLoadB(
     const int curS = curIdx.w;
     #pragma unroll
     for (unsigned int i = 0; i < NUM_ITERS; i++){
-        const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8;
+        const unsigned int src_index = thread_row * src_stride + kidx;
         // TODO : move some checks outside of the loop
-        if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c){
+        if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curT < param.t
+            && curC < param.c && kidx < end_k){
             dst_reg[i] = reinterpret_cast<const float4 *>(&src[src_index])[0];
         }else{ // read 4 halves
             dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
@@ -350,9 +350,9 @@ int main(void)
 {
     ggml_time_init();
     std::vector<std::tuple<int, int, int, int, int, int, int, int>> configs = {
-        // std::make_tuple(1,2,16,32,4,3,3,3),
+        std::make_tuple(1,2,16,32,4,3,3,3),
-        // std::make_tuple(320,1280,26,38,8,3,3,3),
+        std::make_tuple(1280,1280,26,38,8,3,3,3),
         // std::make_tuple(1280,1280,26,38,8,3,3,3),
         // std::make_tuple(320,1280,52,76,8,3,3,3),
         // std::make_tuple(1280,1280,52,76,8,3,3,3),
         // std::make_tuple(320,1280,104,152,8,3,3,3),