fixed a bug; now all test cases pass

bssrdf 2025-11-03 08:46:17 -05:00
parent 3308ccef91
commit 2357922a2f
3 changed files with 76 additions and 209 deletions

View File

@ -163,7 +163,7 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
const uint inKOffset = start_k + innerColA * 4;
#pragma unroll
for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
const unsigned int gemm_i = bx * BM + innerRowA + offset;
// int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQZ : z;
int n = (ksplit > 0) ? fastdiv(gemm_i, param.PQZ_fastdiv) : z;
const unsigned int npqz_res = fastmodulo(gemm_i, param.PQZ_fastdiv);
@ -173,26 +173,7 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
const int posw_ori = fastmodulo(ohow_res, param.OW_fastdiv) * param.stride0 - param.padding0;
int inOffset = n * inNOffset;
if(vec_load){
// const uint cur0 = fastdiv(inKOffset,
// layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
// const uint cur0_res = fastmodulo(inKOffset,
// layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
// const uint cur1 = fastdiv(cur0_res,
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
// const uint cur1_res = fastmodulo(cur0_res,
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
// const uint cur2 = fastdiv(cur1_res,
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
// const uint cur3 = fastmodulo(cur1_res,
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
// const uint curC = layout == 0 ? cur3 : cur0;
// const uint curT = layout == 0 ? cur0 : cur1;
// const uint curR = layout == 0 ? cur1 : cur2;
// const uint curS = layout == 0 ? cur2 : cur3;
const int4 curIdx = inputIndices<layout>(inKOffset, param);
// const int curD = posd_ori + curT * param.dilation2; // input w
// const int curH = posh_ori + curR * param.dilation1; // input h
// const int curW = posw_ori + curS * param.dilation0; // input w
const int curD = posd_ori + curIdx.y * param.dilation2; // input d
const int curH = posh_ori + curIdx.z * param.dilation1; // input h
const int curW = posw_ori + curIdx.w * param.dilation0; // input w
@ -214,43 +195,11 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
} else {
#pragma unroll
for (int i = 0; i < 4; ++i){
// const uint cur0 = fastdiv(inKOffset + i,
// layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
// const uint cur0_res = fastmodulo(inKOffset + i,
// layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
// const uint cur1 = fastdiv(cur0_res,
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
// const uint cur1_res = fastmodulo(cur0_res,
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
// const uint cur2 = fastdiv(cur1_res,
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
// const uint cur3 = fastmodulo(cur1_res,
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
// const uint curC = layout == 0 ? cur3 : cur0;
// const uint curT = layout == 0 ? cur0 : cur1;
// const uint curR = layout == 0 ? cur1 : cur2;
// const uint curS = layout == 0 ? cur2 : cur3;
const int4 curIdx = inputIndices<layout>(inKOffset + i, param);
// const int curD = posd_ori + curT * param.dilation2; // input w
// const int curH = posh_ori + curR * param.dilation1; // input h
// const int curW = posw_ori + curS * param.dilation0; // input w
const int curD = posd_ori + curIdx.y * param.dilation2; // input d
const int curH = posh_ori + curIdx.z * param.dilation1; // input h
const int curW = posw_ori + curIdx.w * param.dilation0; // input w
const int curC = curIdx.x;
// const uint cur0 = fastdiv(start_k + innerColA * 4 + i,
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
// const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4 + i,
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
// const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4 + i,
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
// const uint curC = layout == 0 ? cur2 : cur0;
// const uint curR = layout == 0 ? cur0 : cur1;
// const uint curS = layout == 0 ? cur1 : cur2;
// const int curH = posh_ori + curR * param.d_h; // input h
// const int curW = posw_ori + curS * param.d_w; // input w
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && inKOffset + i < end_k){
int inOffsetTmp = layout == 0 ?
curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC:
@ -360,12 +309,7 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
const uint inKkOffset = innerColA * 4 + crs + BK;
#pragma unroll
for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
// int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z;
// const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ;
// const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p;
// const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q;
// int inOffset = n * param.c * param.h * param.w ;
const unsigned int gemm_i = bx * BM + innerRowA + offset;
int n = (ksplit > 0) ? fastdiv(gemm_i, param.PQZ_fastdiv) : z;
const unsigned int npqz_res = fastmodulo(gemm_i, param.PQZ_fastdiv);
const int posd_ori = fastdiv((ksplit > 0) ? npqz_res: gemm_i, param.OHOW_fastdiv) * param.stride2 - param.padding2;
@ -379,28 +323,10 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
const int curH = posh_ori + curIdx.z * param.dilation1; // input h
const int curW = posw_ori + curIdx.w * param.dilation0; // input w
const int curC = curIdx.x;
// const uint cur0 = fastdiv(innerColA * 4 + crs + BK,
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
// const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK,
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
// const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK,
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
// const uint curC = layout == 0 ? cur2 : cur0;
// const uint curR = layout == 0 ? cur0 : cur1;
// const uint curS = layout == 0 ? cur1 : cur2;
// const int curH = posh_ori + curR * param.d_h; // input h
// const int curW = posw_ori + curS * param.d_w; // input w
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && inKkOffset < end_k){
int inOffsetTmp = layout == 0 ?
curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC:
curC * inDepthOffset + curD * inChannelOffset + curH * param.w + curW;
// if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && inKkOffset < end_k){
// int inOffsetTmp = layout == 0 ?
// curH * inChannelOffset + curW * param.c + curC:
// curC * inChannelOffset + curH * param.w + curW;
float4 tmp = reinterpret_cast<const float4 *>(&input[inOffset + inOffsetTmp])[0];
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + 0] = tmp.x;
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + BM+PAD] = tmp.y;
@ -414,29 +340,11 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
} else {
#pragma unroll
for (int i = 0; i < 4; ++i){
// const uint cur0 = fastdiv(innerColA * 4 + crs + BK + i,
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
// const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i,
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
// const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i,
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
// const uint curC = layout == 0 ? cur2 : cur0;
// const uint curR = layout == 0 ? cur0 : cur1;
// const uint curS = layout == 0 ? cur1 : cur2;
// const int curH = posh_ori + curR * param.d_h; // input h
// const int curW = posw_ori + curS * param.d_w; // input w
const int4 curIdx = inputIndices<layout>(inKkOffset + i, param);
const int curD = posd_ori + curIdx.y * param.dilation2; // input d
const int curH = posh_ori + curIdx.z * param.dilation1; // input h
const int curW = posw_ori + curIdx.w * param.dilation0; // input w
const int curC = curIdx.x;
// if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK + i < end_k){
// int inOffsetTmp = layout == 0 ?
// curH * inChannelOffset + curW * param.c + curC:
// curC * inChannelOffset + curH * param.w + curW;
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && inKkOffset + i < end_k){
int inOffsetTmp = layout == 0 ?
curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC:
@ -521,7 +429,6 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
const int col = (ksplit > 0) ? fastmodulo(gemm_i, param.PQZ_fastdiv) : gemm_i;
if (n < param.n && row < param.k && col < PQZ){
const uint outOffset = ksplit > 0 ?
// z * param.n * param.k * PQZ + n * param.k * PQZ + row * PQZ + col :
((z * param.n + n) * param.k + row) * PQZ + col :
(z * param.k + row) * PQZ + col;
output[outOffset] = smemoutput[output_lds_addr + subk * WARPSIZE];
@ -790,7 +697,7 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
const unsigned int K = param.c * param.r * param.s * param.t;
const uint weightKOffset = K; //param.c * param.r * param.s * param.t;
const uint inChannelOffset = param.c * param.w;
const uint inDepthOffset = param.h * param.c * param.w;
const uint inNOffset = param.c * param.w * param.h * param.d;
// loop bounds, constexpr where possible allows for loop unrolling
@ -854,7 +761,7 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
if (block_k != num_block_tiles_k){
const half* A_block_gmem = input;
const half* B_block_gmem = kernel + (block_n * BN * weightKOffset);
tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, block_k * BK,
inNOffset, inDepthOffset, inChannelOffset, param);
tileMemcpyLoadB<BN, BK, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param);
}
@ -935,12 +842,9 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
for (int j = 0; j < 4; ++j){
const uint row = m_idx + subk + i * WN / 2;
const uint gemm_i = n_idx + j*32;
// const int n = fastdiv(gemm_i, param.OHOW_fastdiv);
// const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
const int n = fastdiv(gemm_i, param.PQZ_fastdiv);
const int col = fastmodulo(gemm_i, param.PQZ_fastdiv);
if(n < param.n && row < param.k && col < PQZ){
// const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
const uint outOffset = (n * param.k + row) * PQZ + col;
uint idx = output_lds_addr + subk + j*32*BN/2;
idx = idx ^ ((idx & 0b1110000000) >> 4);
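
The shared-memory index in the last two lines above is XOR-swizzled, presumably to spread the staged output reads across shared-memory banks. A self-contained sketch of that exact bit trick (mask and shift copied from the source; the row stride of 128 elements below is a hypothetical value for illustration, not taken from the kernel's tile geometry):

#include <cstdio>

// Folds bits 7..9 of the index into bits 3..5 (0b1110000000 >> 4 == 0b0000111000),
// so indices that differ only in bits 7..9 (e.g. rows 128 elements apart)
// land at different bank offsets instead of all starting in bank 0.
static unsigned swizzle(unsigned idx) {
    return idx ^ ((idx & 0b1110000000u) >> 4);
}

int main() {
    for (unsigned row = 0; row < 8; ++row) {
        const unsigned base = row * 128;  // hypothetical row stride
        std::printf("base %4u -> %4u (bank %2u)\n", base, swizzle(base), swizzle(base) % 32);
    }
    return 0;
}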
@ -1109,19 +1013,15 @@ void ggml_cuda_op_conv3d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor *
const uint KW = kernel->ne[0]; // kernel_w
const uint KH = kernel->ne[1]; // kernel_h
const uint KD = kernel->ne[2]; // kernel_d
// const uint IC = input->ne[2]; // input_channels
// const uint OC = kernel->ne[3]; // output_channels
// const uint B = input->ne[3]; // n_batches
param_t params = { B,
IC,
IH, IW, ID,
OC,
KH, KW, KD,
ST_Y, ST_X, ST_Z,
PD_Y, PD_X, PD_Z,
DL_Y, DL_X, DL_Z,
ST_X, ST_Y, ST_Z,
PD_X, PD_Y, PD_Z,
DL_X, DL_Y, DL_Z,
OH, OW, OD,
init_fastdiv_values(KW*IC),
init_fastdiv_values(OW),

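A side note on the index math used throughout the kernel above: the GEMM row index gemm_i is split into a batch index and an output (d, h, w) position with fastdiv/fastmodulo. Below is a plain-division sketch of that mapping, assuming PQZ = OD*OH*OW and OHOW = OH*OW (which is how the corresponding fastdiv values appear to be initialized); the device code avoids the actual divisions.

// Illustrative host-side sketch only, not the kernel code.
struct OutPos { unsigned n, od, oh, ow; };

static OutPos decode_gemm_row(unsigned gemm_i, unsigned OW, unsigned OH, unsigned OD) {
    const unsigned OHOW = OH * OW;       // one output depth slice
    const unsigned PQZ  = OD * OHOW;     // one full output volume
    OutPos p;
    p.n = gemm_i / PQZ;                  // batch index (ksplit path)
    const unsigned npqz_res = gemm_i % PQZ;
    p.od = npqz_res / OHOW;              // output depth
    p.oh = (npqz_res % OHOW) / OW;       // output height
    p.ow = (npqz_res % OHOW) % OW;       // output width
    return p;
}
// The input-side origin then follows the same pattern as the kernel, e.g.
// posd_ori = p.od * stride2 - padding2 and posh_ori = p.oh * stride1 - padding1.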
View File

@ -11,15 +11,15 @@ typedef struct{
unsigned int r; //filter height
unsigned int s; //filter width
unsigned int t; //filter depth
unsigned int stride0; //stride width
unsigned int stride1; //stride height
unsigned int stride2; //stride depth
unsigned int padding0; //padding width
unsigned int padding1; //padding height
unsigned int padding2; //padding depth
unsigned int dilation0; //dilation width
unsigned int dilation1; //dilation height
unsigned int dilation2; //dilation depth
unsigned int Oh; //output height
unsigned int Ow; //output width
unsigned int Od; //output depth
@ -39,17 +39,17 @@ template<const int layout>
__device__ __forceinline__ int4 inputIndices(const unsigned int kidx, param_t param) {
const unsigned int cur0 = fastdiv(kidx,
layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv);
const unsigned int cur0_res = fastmodulo(kidx,
layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv);
const unsigned int cur1 = fastdiv(cur0_res,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv);
const unsigned int cur1_res = fastmodulo(cur0_res,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv);
const unsigned int cur2 = fastdiv(cur1_res,
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
layout == 0 ? param.C_fastdiv : param.S_fastdiv);
const unsigned int cur3 = fastmodulo(cur1_res,
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
layout == 0 ? param.C_fastdiv : param.S_fastdiv);
const unsigned int curC = layout == 0 ? cur3 : cur0;
const unsigned int curT = layout == 0 ? cur0 : cur1;
const unsigned int curR = layout == 0 ? cur1 : cur2;
@ -90,9 +90,6 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
// const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset
// const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
// const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); //
const unsigned int kidx = thread_col*8;
const int4 curIdx = inputIndices<0>(kidx, param);
const int curC = curIdx.x;
@ -172,17 +169,6 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
const int ohow_res = fastmodulo(npqz_res, param.OHOW_fastdiv);
const int posh_ori = fastdiv(ohow_res, param.OW_fastdiv) * param.stride1 - param.padding1;
const int posw_ori = fastmodulo(ohow_res, param.OW_fastdiv) * param.stride0 - param.padding0;
// unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv);
// unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
// int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.stride1 - param.padding1;
// int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.stride0 - param.padding0;
// unsigned int inOffset = n * inNOffset;
// const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset
// const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
// const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
// int curH = posh_ori + curR * param.dilation1; // input h
// int curW = posw_ori + curS * param.dilation0; // input w
const int curD = posd_ori + curIdx.y * param.dilation2; // input d
const int curH = posh_ori + curIdx.z * param.dilation1; // input h
const int curW = posw_ori + curIdx.w * param.dilation0; // input w
@ -193,9 +179,6 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && curC < param.c){
int inOffsetTmp = curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC;
// if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
// curR < param.r && curS < param.s && curC < param.c){
// const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[n * inNOffset + inOffsetTmp])[0];
} else{
dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
@ -250,31 +233,18 @@ __device__ __forceinline__ void tileMemcpyLoadA(
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++){
unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
// unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv);
// unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
// int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.stride1 - param.padding1;
// int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.stride0 - param.padding0;
unsigned int n = fastdiv(gemm_i, param.PQZ_fastdiv);
const unsigned int npqz_res = fastmodulo(gemm_i, param.PQZ_fastdiv);
const int posd_ori = fastdiv(npqz_res, param.OHOW_fastdiv) * param.stride2 - param.padding2;
const int ohow_res = fastmodulo(npqz_res, param.OHOW_fastdiv);
const int posh_ori = fastdiv(ohow_res, param.OW_fastdiv) * param.stride1 - param.padding1;
const int posw_ori = fastmodulo(ohow_res, param.OW_fastdiv) * param.stride0 - param.padding0;
// unsigned int inOffset = n * param.c * param.h * param.w;
// const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset
// const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
// const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
// int curH = posh_ori + curR * param.dilation1; // input h
// int curW = posw_ori + curS * param.dilation0; // input w
const int curD = posd_ori + curIdx.y * param.dilation2; // input d
const int curH = posh_ori + curIdx.z * param.dilation1; // input h
const int curW = posw_ori + curIdx.w * param.dilation0; // input w
const int curC = curIdx.x;
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && curC < param.c){
int inOffsetTmp = curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC;
// if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
// curR < param.r && curS < param.s && curC < param.c){
// const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
dst_reg[i] = reinterpret_cast<const float4 *>(&src[n * inNOffset + inOffsetTmp])[0];
} else{
dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
@ -322,9 +292,6 @@ __device__ __forceinline__ void tileMemcpyLoadB(
// compile time check that we provided the right amount of registers for storage
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
// const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset
// const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
// const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); //
const unsigned int kidx = block_k + thread_col*8;
const int4 curIdx = inputIndices<0>(kidx, param);
const int curC = curIdx.x;
@ -334,7 +301,6 @@ __device__ __forceinline__ void tileMemcpyLoadB(
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++){
const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8;
// if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){
// TODO : move some checks outside of the loop
if (thread_row < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c){
dst_reg[i] = reinterpret_cast<const float4 *>(&src[src_index])[0];

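inputIndices<layout> in the header above replaces the fastdiv/fastmodulo chains that were previously repeated at every call site; it unpacks a flattened reduction index into (c, t, r, s), returned as an int4. A plain-division sketch of the same decomposition, assuming the divisor choices (RSC/SC/C vs. TRS/RS/S) mean layout 0 orders the reduction axis as (t, r, s, c) with the channel fastest and layout 1 as (c, t, r, s) with s fastest:

// Illustrative sketch with ordinary division; the kernel uses fastdiv/fastmodulo.
struct KIdx { unsigned c, t, r, s; };

static KIdx decode_k(unsigned kidx, unsigned C, unsigned T, unsigned R, unsigned S, int layout) {
    KIdx out;
    if (layout == 0) {            // kidx = ((t * R + r) * S + s) * C + c
        out.t = kidx / (R * S * C);
        unsigned rem = kidx % (R * S * C);
        out.r = rem / (S * C);
        rem   = rem % (S * C);
        out.s = rem / C;
        out.c = rem % C;
    } else {                      // kidx = ((c * T + t) * R + r) * S + s
        out.c = kidx / (T * R * S);
        unsigned rem = kidx % (T * R * S);
        out.t = rem / (R * S);
        rem   = rem % (R * S);
        out.r = rem / S;
        out.s = rem % S;
    }
    return out;
}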
View File

@ -38,7 +38,9 @@ struct test_model {
void load_model(test_model & model, int ic, int oc, int iw, int ih, int id, int kw = 3, int kh = 3, int kd = 3, bool use_gpu = false ) {
void load_model(test_model & model, int ic, int oc, int iw, int ih, int id,
int kw = 3, int kh = 3, int kd = 3,
bool use_fp16 = true, bool use_gpu = false ) {
// create data
int KW = kw, KH = kh, KD = kd;
int IC = ic, OC = oc;
@ -72,9 +74,10 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, int id, int
}
size_t buffer_size = 0;
{
// buffer_size += KW * KH * IC * OC * ggml_type_size(GGML_TYPE_F32); // tensor a
buffer_size += KW * KH * KD * IC * OC * ggml_type_size(GGML_TYPE_F16); // tensor a
if(use_fp16)
buffer_size += KW * KH * KD * IC * OC * ggml_type_size(GGML_TYPE_F16); // tensor a
else
buffer_size += KW * KH * KD * IC * OC * ggml_type_size(GGML_TYPE_F32); // tensor a
buffer_size += IW * IH * ID * IC * N * ggml_type_size(GGML_TYPE_F32); // tensor b
buffer_size += 1024; // overhead
}
@ -122,8 +125,10 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, int id, int
model.ctx = ggml_init(params);
// create tensors
model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F16, KW, KH, KD, IC*OC);
// model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, KW, KH, IC, OC);
if(use_fp16)
model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F16, KW, KH, KD, IC*OC);
else
model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, KW, KH, KD, IC*OC);
model.b = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, IW, IH, ID, IC*N);
// create a allocator
@ -134,11 +139,15 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, int id, int
// load data to buffer
if(ggml_backend_is_cpu(model.backend)) {
memcpy(model.a->data, hadata.data(), ggml_nbytes(model.a));
// memcpy(model.a->data, adata.data(), ggml_nbytes(model.a));
if(use_fp16)
memcpy(model.a->data, hadata.data(), ggml_nbytes(model.a));
else
memcpy(model.a->data, adata.data(), ggml_nbytes(model.a));
} else {
ggml_backend_tensor_set(model.a, hadata.data(), 0, ggml_nbytes(model.a));
// ggml_backend_tensor_set(model.a, adata.data(), 0, ggml_nbytes(model.a));
if(use_fp16)
ggml_backend_tensor_set(model.a, hadata.data(), 0, ggml_nbytes(model.a));
else
ggml_backend_tensor_set(model.a, adata.data(), 0, ggml_nbytes(model.a));
}
// alloc memory
@ -155,7 +164,7 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, int id, int
}
}
typedef struct ggml_cgraph* (*build_graph_t)(const test_model& model,
const int64_t i0, const int64_t i1, const int64_t i2);
struct ggml_cgraph * build_graph_0(const test_model& model, const int64_t ic, const int64_t n, const int64_t oc) {
@ -173,18 +182,27 @@ struct ggml_cgraph * build_graph_0(const test_model& model, const int64_t ic, co
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
// int s0 = 2;
// int s1 = 1;
// int s2 = 1;
// int p0 = 2;
// int p1 = 0;
// int p2 = 1;
// int d0 = 1;
// int d1 = 1;
// int d2 = 2;
int s0 = 1;
int s1 = 1;
int s2 = 1;
int p0 = 1;
int p1 = 1;
int p2 = 1;
int d0 = 1;
int d1 = 1;
int d2 = 1;
// recalculate to avoid fragmentation
struct ggml_tensor* conv2d_res = ggml_conv_3d(ctx0, model.a, model.b, ic, s0, s1, s2, p0, p1, p2, d0, d1, d2);
ggml_set_name(conv2d_res, "conv2d_res");
@ -227,6 +245,16 @@ struct ggml_cgraph * build_graph_1(const test_model& model, const int64_t ic, co
int d1 = 1;
int d2 = 1;
// int s0 = 2;
// int s1 = 1;
// int s2 = 1;
// int p0 = 2;
// int p1 = 0;
// int p2 = 1;
// int d0 = 1;
// int d1 = 1;
// int d2 = 2;
// recalculate to avoid fragmentation
// struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1);
// ggml_set_name(conv2d_res, "conv2d_res");
@ -236,7 +264,7 @@ struct ggml_cgraph * build_graph_1(const test_model& model, const int64_t ic, co
// struct ggml_tensor* wino_res = ggml_conv_2d_implicitgemm(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1);
struct ggml_tensor* wino_res = ggml_conv_3d_direct(ctx0, model.a, model.b,
s0, s1, s2, p0, p1, p2, d0, d1, d2,
ic, n, oc);
ggml_set_name(wino_res, "wino_res");
@ -251,7 +279,7 @@ struct ggml_cgraph * build_graph_1(const test_model& model, const int64_t ic, co
std::vector<float> compute_graph(const test_model & model, ggml_gallocr_t allocr,
build_graph_t build_graph, int iters,
const int64_t ic, const int64_t n, const int64_t oc, double *t) {
struct ggml_cgraph * gf = build_graph(model, ic, n, oc);
@ -271,7 +299,6 @@ std::vector<float> compute_graph(const test_model & model, ggml_gallocr_t allocr
}
#endif
ggml_backend_graph_compute(model.backend, gf);
@ -289,8 +316,6 @@ std::vector<float> compute_graph(const test_model & model, ggml_gallocr_t allocr
double time_us = end_time - start_time;
time_us = time_us/iters;
// printf(" Taking %f ms\n ", time_us/1000);
//ggml_graph_print(gf);
struct ggml_tensor *res = NULL;
@ -316,12 +341,6 @@ int main(void)
{
ggml_time_init();
std::vector<std::tuple<int, int, int, int, int, int, int, int>> configs = {
// std::make_tuple(64,64,48,64,3,3),
// std::make_tuple(320,320,104,152,3,3),
// std::make_tuple(640,640,52,76,3,3),
// std::make_tuple(640,640,104,152,3,3),
// std::make_tuple(960,320,104,152,3,3),
// std::make_tuple(1280,1280,26,38,3,3),
std::make_tuple(320,1280,26,38,8,3,3,3),
std::make_tuple(1280,1280,26,38,8,3,3,3),
std::make_tuple(320,1280,52,76,8,3,3,3),
@ -330,29 +349,14 @@ int main(void)
std::make_tuple(1280,1280,104,152,8,3,3,3),
std::make_tuple(320,1280,208,304,4,3,3,3),
std::make_tuple(640,1280,208,304,4,3,3,3),
// std::make_tuple(1280,1280,26,38,1,1),
// std::make_tuple(256,128,768,1024,3,3),
// std::make_tuple(128,3,768,1024,3,3),
// std::make_tuple(256,128,768,1024,1,1),
// std::make_tuple(512,256,384,512,1,1),
// std::make_tuple(1280,640,52,76,3,3),
// std::make_tuple(1920,1280,26,38,3,3),
// std::make_tuple(2560,1280,26,38,3,3),
// std::make_tuple(320,1280,26,38,3,3),
// std::make_tuple(512,512,104,152,3,3),
// std::make_tuple(512,512,208,304,3,3),
// std::make_tuple(512,256,416,608,3,3),
// std::make_tuple(256,128,832,1216,3,3),
// std::make_tuple(256,256,832,1216,3,3),
// std::make_tuple(320,256,1024,1920)
};
int k = 0;
for (auto c : configs){
test_model model;
load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c),
std::get<3>(c), std::get<4>(c), std::get<5>(c), std::get<6>(c), std::get<7>(c), true);
load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c),
std::get<3>(c), std::get<4>(c), std::get<5>(c), std::get<6>(c), std::get<7>(c), true, true);
ggml_gallocr_t allocr = NULL;
allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend));
@ -366,11 +370,11 @@ int main(void)
// fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f);
struct ggml_cgraph * gf_res_0 = NULL;
int iterations = 20;
double run_time0;
std::vector<float> im2col_data = compute_graph(model, allocr, build_graph_0, iterations,
std::get<0>(c), 1, std::get<1>(c), &run_time0);
ggml_gallocr_free(allocr);
@ -386,23 +390,22 @@ int main(void)
ggml_gallocr_reserve(allocr, gf);
size_t mem_size1 = ggml_gallocr_get_buffer_size(allocr, 0);
// fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f);
struct ggml_cgraph * gf_res_1 = NULL;
double run_time1;
// std::vector<float> wino_data = compute_graph(model, allocr, build_graph_1, iterations, &run_time1);
std::vector<float> conv2d_data = compute_graph(model, allocr, build_graph_1, iterations,
std::get<0>(c), 1, std::get<1>(c), &run_time1);
if(k==0) {
k = 1;
fprintf(stderr, "| (IC, OC, IW, IH, ID, KW, KH, KD) | im2col+GEMM TIME | im2col+GEMM VRAM | implicit GEMM TIME | implicit GEMM VRAM \n");
fprintf(stderr, "| --- | --- | --- | --- | --- \n");
}
fprintf(stderr, " | (%d, %d, %d, %d, %d, %d, %d, %d) | %.2f ms | %.2f MB | %.2f ms | %.2f MB\n",
std::get<0>(c), std::get<1>(c), std::get<2>(c),
std::get<3>(c), std::get<4>(c), std::get<5>(c),
std::get<6>(c), std::get<7>(c),
run_time0, mem_size0/1024.0f/1024.0f,
@ -412,7 +415,7 @@ int main(void)
// for(int i = 0; i < conv2d_data.size(); i++) {
// float diff = fabs(im2col_data[i] - conv2d_data[i]);
// // if(diff > 0.5) {
// printf("(%7.3f, %7.3f, %.2f, %d) \n",
// printf("(%7.3f, %7.3f, %f, %d) \n",
// im2col_data[i], conv2d_data[i],
// diff, i);
// // break;
@ -425,7 +428,5 @@ int main(void)
ggml_gallocr_free(allocr);
}
// printf("\nPerforming test:\n");
return 0;
}
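
For reference when reading the benchmark table printed above: the output sizes implied by each config follow the usual convolution output-size formula, stated here as an assumption consistent with the OH/OW/OD values the implicit-GEMM path uses.

// Standard convolution output-size arithmetic, sketch only.
static inline int conv_out(int in, int kernel, int stride, int pad, int dil) {
    return (in + 2 * pad - dil * (kernel - 1) - 1) / stride + 1;
}

// Example for the first benchmark config (IC=320, OC=1280, IW=26, IH=38, ID=8,
// 3x3x3 kernel) with the stride/padding/dilation of 1 set in build_graph_0/1:
//   OW = conv_out(26, 3, 1, 1, 1) = 26
//   OH = conv_out(38, 3, 1, 1, 1) = 38
//   OD = conv_out( 8, 3, 1, 1, 1) =  8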