clean up
This commit is contained in:
parent
1e568252b5
commit
2dfbbee73f
|
|
@ -8,6 +8,8 @@
|
||||||
typedef unsigned int uint;
|
typedef unsigned int uint;
|
||||||
constexpr uint WARPSIZE = 32;
|
constexpr uint WARPSIZE = 32;
|
||||||
|
|
||||||
|
|
||||||
|
//currently not use; in future for split-k kernels
|
||||||
static __global__ void reduce_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const int nrows) {
|
static __global__ void reduce_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const int nrows) {
|
||||||
const int row = blockIdx.x;
|
const int row = blockIdx.x;
|
||||||
const int col = threadIdx.x;
|
const int col = threadIdx.x;
|
||||||
|
|
@ -31,11 +33,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||||
float * __restrict__ output,
|
float * __restrict__ output,
|
||||||
const param_t param) {
|
const param_t param) {
|
||||||
|
|
||||||
// __shared__ char smem[4 * (TM*TN*NUM_THREADS <= (BM * BK + BK * (BN+PAD)) ? (BM * BK + BK * (BN+PAD)) : (TM*TN*NUM_THREADS))];
|
|
||||||
__shared__ char smem[sizeof(float) * (TM*TN*NUM_THREADS) <= sizeof(float) * 2 * (BM+PAD) * BK + sizeof(T)*2*BK * (BN+PAD) ?
|
__shared__ char smem[sizeof(float) * (TM*TN*NUM_THREADS) <= sizeof(float) * 2 * (BM+PAD) * BK + sizeof(T)*2*BK * (BN+PAD) ?
|
||||||
sizeof(float)*2*(BM+PAD)*BK + sizeof(T)*2*BK*(BN+PAD) : sizeof(float) * (TM*TN*NUM_THREADS)];
|
sizeof(float)*2*(BM+PAD)*BK + sizeof(T)*2*BK*(BN+PAD) : sizeof(float) * (TM*TN*NUM_THREADS)];
|
||||||
// __shared__ float smeminput[2 * BM * BK];
|
|
||||||
// __shared__ float smemweight[2 * BK * (BN+PAD)];
|
|
||||||
T *smemweight = reinterpret_cast<T *>(smem);
|
T *smemweight = reinterpret_cast<T *>(smem);
|
||||||
float *smeminput = reinterpret_cast<float *>(smem + 2 * BK * (BN+PAD) * sizeof(T));
|
float *smeminput = reinterpret_cast<float *>(smem + 2 * BK * (BN+PAD) * sizeof(T));
|
||||||
|
|
||||||
|
|
@ -48,12 +47,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||||
// Warp tile
|
// Warp tile
|
||||||
const uint lane_id = tx % WARPSIZE;
|
const uint lane_id = tx % WARPSIZE;
|
||||||
const uint warp_id = tx / WARPSIZE;
|
const uint warp_id = tx / WARPSIZE;
|
||||||
const int mma_tid_x = warp_id / (BN / WN); //(lane_id / 2) % 8;
|
const int mma_tid_x = warp_id / (BN / WN);
|
||||||
const int mma_tid_y = warp_id % (BN / WN); //(lane_id / 16) * 2 + (lane_id % 2);
|
const int mma_tid_y = warp_id % (BN / WN);
|
||||||
|
|
||||||
// lds addr
|
|
||||||
// int weight_lds_addr = (warp_id / 2) * 32 + mma_tid_y * 4;
|
|
||||||
// int input_lds_addr = (warp_id % 2) * 64 + mma_tid_x * 4;
|
|
||||||
|
|
||||||
// size of the warp subtile
|
// size of the warp subtile
|
||||||
constexpr uint WMITER = (WM * WN) / (WARPSIZE * TM * TN * WNITER);
|
constexpr uint WMITER = (WM * WN) / (WARPSIZE * TM * TN * WNITER);
|
||||||
|
|
@ -61,75 +56,34 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||||
constexpr uint WSUBN = WN / WNITER; // 32/2=16
|
constexpr uint WSUBN = WN / WNITER; // 32/2=16
|
||||||
|
|
||||||
// Placement of the thread in the warp subtile
|
// Placement of the thread in the warp subtile
|
||||||
// const uint threadIdxInWarp = tx % WARPSIZE; // [0, 31]
|
|
||||||
const uint threadColInWarp = lane_id % (WSUBN / TN); // i%(16/4)
|
const uint threadColInWarp = lane_id % (WSUBN / TN); // i%(16/4)
|
||||||
const uint threadRowInWarp = lane_id / (WSUBN / TN); // i/4
|
const uint threadRowInWarp = lane_id / (WSUBN / TN); // i/4
|
||||||
|
|
||||||
// int x = bx * BM + input_lds_addr;
|
|
||||||
// int y = by * BN + weight_lds_addr;
|
|
||||||
int z = blockIdx.z;
|
int z = blockIdx.z;
|
||||||
|
|
||||||
|
|
||||||
// float weight_ldg_reg[4];
|
|
||||||
// float input_ldg_reg[4];
|
|
||||||
// 当前线程处理的数据点在oh、ow上的坐标
|
|
||||||
// int posh_ori = ((bx * 128 + tx / 2 ) / param.Ow) * param.u - param.p;
|
|
||||||
// int posw_ori = ((bx * 128 + tx / 2 ) % param.Ow) * param.v - param.q;
|
|
||||||
// int posh_ori = fastdiv(bx * BM + tx / 2, param.OW_fastdiv) * param.u - param.p;
|
|
||||||
// int posw_ori = fastmodulo(bx * BM + tx / 2, param.OW_fastdiv) * param.v - param.q;
|
|
||||||
|
|
||||||
|
|
||||||
// int inOffset = (ksplit > 0): z * param.c * param.h * param.w ;
|
|
||||||
// int weiOffset = (by * BN + tx / 8 * 4) * param.c * param.r * param.s;
|
|
||||||
int inChannelOffset = layout == 0 ? param.c * param.w : param.h * param.w;
|
int inChannelOffset = layout == 0 ? param.c * param.w : param.h * param.w;
|
||||||
// int weightChannelOffset = param.r * param.s;
|
|
||||||
int weightKOffset = param.c * param.r * param.s;
|
int weightKOffset = param.c * param.r * param.s;
|
||||||
|
|
||||||
// uint ks, start_k;
|
|
||||||
|
|
||||||
// if constexpr (ksplit > 0){
|
|
||||||
// const uint ks = (weightKOffset + ksplit - 1) / ksplit;
|
|
||||||
// const uint start_k = z * ks;
|
|
||||||
// } else {
|
|
||||||
// const uint ks = weightKOffset;
|
|
||||||
// const uint start_k = 0;
|
|
||||||
// }
|
|
||||||
const uint ks = (ksplit > 0) ? (weightKOffset + ksplit - 1) / ksplit : weightKOffset;
|
const uint ks = (ksplit > 0) ? (weightKOffset + ksplit - 1) / ksplit : weightKOffset;
|
||||||
const uint start_k = (ksplit > 0)? z * ks: 0;
|
const uint start_k = (ksplit > 0)? z * ks: 0;
|
||||||
const uint end_k = min(start_k + ks, weightKOffset);
|
const uint end_k = min(start_k + ks, weightKOffset);
|
||||||
|
|
||||||
// sts addr
|
|
||||||
// int weight_sts_addr = (tx % 8) * 132 +
|
|
||||||
// (tx / 8) * 4;
|
|
||||||
int write_flag = 1;
|
int write_flag = 1;
|
||||||
T weight_frag[2][WNITER * TN];
|
T weight_frag[2][WNITER * TN];
|
||||||
float input_frag[2][WMITER * TM] = {0.f};
|
float input_frag[2][WMITER * TM] = {0.f};
|
||||||
float output_frag[WMITER * TM * WNITER * TN] = {0.f};
|
float output_frag[WMITER * TM * WNITER * TN] = {0.f};
|
||||||
// #pragma unroll
|
|
||||||
// for (int i = 0; i < 8; ++i)
|
|
||||||
// {
|
|
||||||
// #pragma unroll
|
|
||||||
// for (int j = 0; j < 8; ++j)
|
|
||||||
// {
|
|
||||||
// output_frag[i][j] = 0;
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// calculating the indices that this thread will load into SMEM
|
// calculating the indices that this thread will load into SMEM
|
||||||
// we'll load 128bit / 32bit = 4 elements per thread at each step
|
// we'll load 128bit / 32bit = 4 elements per thread at each step
|
||||||
const uint innerRowA = tx / (BK / 4);
|
const uint innerRowA = tx / (BK / 4);
|
||||||
const uint innerColA = tx % (BK / 4);
|
const uint innerColA = tx % (BK / 4);
|
||||||
constexpr uint rowStrideA = (NUM_THREADS * 4) / BK;
|
constexpr uint rowStrideA = (NUM_THREADS * 4) / BK;
|
||||||
// const uint innerRowB = tx / (BN / 4);
|
|
||||||
// const uint innerColB = tx % (BN / 4);
|
|
||||||
// constexpr uint rowStrideB = NUM_THREADS / (BN / 4);
|
|
||||||
|
|
||||||
// ldg
|
// ldg
|
||||||
const uint weight_sts_addr = innerRowA + innerColA * (BN+PAD) * 4;
|
const uint weight_sts_addr = innerRowA + innerColA * (BN+PAD) * 4;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint offset = 0; offset + rowStrideA <= BN; offset += rowStrideA) {
|
for (uint offset = 0; offset + rowStrideA <= BN; offset += rowStrideA) {
|
||||||
if(vec_load){
|
if(vec_load){
|
||||||
// if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 < param.c * param.r * param.s){
|
|
||||||
if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 < end_k){
|
if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 < end_k){
|
||||||
if constexpr (std::is_same_v<T, float>){
|
if constexpr (std::is_same_v<T, float>){
|
||||||
float4 tmp = reinterpret_cast<const float4 *>(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0];
|
float4 tmp = reinterpret_cast<const float4 *>(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0];
|
||||||
|
|
@ -138,26 +92,23 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||||
smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = tmp.z;
|
smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = tmp.z;
|
||||||
smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = tmp.w;
|
smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = tmp.w;
|
||||||
}else{ // read 4 halves
|
}else{ // read 4 halves
|
||||||
// half val[4];
|
|
||||||
float2 tmp = reinterpret_cast<const float2 *>(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0];
|
float2 tmp = reinterpret_cast<const float2 *>(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0];
|
||||||
const half *val = reinterpret_cast<const half *>(&tmp);
|
const half *val = reinterpret_cast<const half *>(&tmp);
|
||||||
// val[1] = reinterpret_cast<half2 *>(&tmp.y);
|
|
||||||
smemweight[weight_sts_addr + offset + 0] = val[0];
|
smemweight[weight_sts_addr + offset + 0] = val[0];
|
||||||
smemweight[weight_sts_addr + offset + (BN+PAD)] = val[1];
|
smemweight[weight_sts_addr + offset + (BN+PAD)] = val[1];
|
||||||
smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = val[2];
|
smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = val[2];
|
||||||
smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = val[3];
|
smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = val[3];
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < 4; ++i){
|
for (int i = 0; i < 4; ++i){
|
||||||
smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f;
|
smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}else{
|
}else{
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < 4; ++i){
|
for (int i = 0; i < 4; ++i){
|
||||||
if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 + i < end_k){
|
if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 + i < end_k){
|
||||||
// float4 tmp = reinterpret_cast<float4 *>(¶m.weight[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4])[0];
|
|
||||||
smemweight[weight_sts_addr + offset + i*(BN+PAD)] = kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4 + i];
|
smemweight[weight_sts_addr + offset + i*(BN+PAD)] = kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4 + i];
|
||||||
} else {
|
} else {
|
||||||
smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f;
|
smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f;
|
||||||
|
|
@ -167,14 +118,6 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// int curC = (tx / 32) / (param.r * param.s); // channel offset
|
|
||||||
// int curR = ((tx / 32) % (param.r * param.s)) / param.s; // kernel r offset
|
|
||||||
// int curS = ((tx / 32) % (param.r * param.s)) % param.s; // kernel s offset
|
|
||||||
|
|
||||||
// int curR = (tx % 2) * 4 / (param.s * param.c); // channel offset
|
|
||||||
// int curS = ((tx % 2) * 4 % (param.s * param.c)) / param.c; // kernel r offset
|
|
||||||
// int curC = ((tx % 2) * 4 % (param.s * param.c)) % param.c; // kernel s offset
|
|
||||||
|
|
||||||
const uint input_sts_addr = innerRowA + innerColA * (BM+PAD) * 4;
|
const uint input_sts_addr = innerRowA + innerColA * (BM+PAD) * 4;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
|
for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
|
||||||
|
|
@ -184,9 +127,6 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||||
const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q;
|
const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q;
|
||||||
int inOffset = n * param.c * param.h * param.w ;
|
int inOffset = n * param.c * param.h * param.w ;
|
||||||
if(vec_load){
|
if(vec_load){
|
||||||
// const uint curR = fastdiv(start_k + innerColA * 4, param.SC_fastdiv); // channel offset
|
|
||||||
// const uint curS = fastdiv(fastmodulo(start_k + innerColA * 4, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
|
||||||
// const uint curC = fastmodulo(fastmodulo(start_k + innerColA * 4, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
|
||||||
const uint cur0 = fastdiv(start_k + innerColA * 4,
|
const uint cur0 = fastdiv(start_k + innerColA * 4,
|
||||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
|
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
|
||||||
const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4,
|
const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4,
|
||||||
|
|
@ -201,7 +141,6 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||||
const int curH = posh_ori + curR * param.d_h; // input h
|
const int curH = posh_ori + curR * param.d_h; // input h
|
||||||
const int curW = posw_ori + curS * param.d_w; // input w
|
const int curW = posw_ori + curS * param.d_w; // input w
|
||||||
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 < end_k){
|
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 < end_k){
|
||||||
// int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
|
|
||||||
int inOffsetTmp = layout == 0 ?
|
int inOffsetTmp = layout == 0 ?
|
||||||
curH * inChannelOffset + curW * param.c + curC:
|
curH * inChannelOffset + curW * param.c + curC:
|
||||||
curC * inChannelOffset + curH * param.w + curW;
|
curC * inChannelOffset + curH * param.w + curW;
|
||||||
|
|
@ -211,16 +150,13 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||||
smeminput[input_sts_addr + offset + 2*(BM+PAD)] = tmp.z;
|
smeminput[input_sts_addr + offset + 2*(BM+PAD)] = tmp.z;
|
||||||
smeminput[input_sts_addr + offset + 3*(BM+PAD)] = tmp.w;
|
smeminput[input_sts_addr + offset + 3*(BM+PAD)] = tmp.w;
|
||||||
} else {
|
} else {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < 4; ++i)
|
for (int i = 0; i < 4; ++i)
|
||||||
smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f;
|
smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < 4; ++i){
|
for (int i = 0; i < 4; ++i){
|
||||||
// const uint curR = fastdiv(start_k + innerColA * 4 + i, param.SC_fastdiv); // channel offset
|
|
||||||
// const uint curS = fastdiv(fastmodulo(start_k + innerColA * 4 + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
|
||||||
// const uint curC = fastmodulo(fastmodulo(start_k + innerColA * 4 + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
|
||||||
const uint cur0 = fastdiv(start_k + innerColA * 4 + i,
|
const uint cur0 = fastdiv(start_k + innerColA * 4 + i,
|
||||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
|
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
|
||||||
const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4 + i,
|
const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4 + i,
|
||||||
|
|
@ -235,7 +171,6 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||||
const int curH = posh_ori + curR * param.d_h; // input h
|
const int curH = posh_ori + curR * param.d_h; // input h
|
||||||
const int curW = posw_ori + curS * param.d_w; // input w
|
const int curW = posw_ori + curS * param.d_w; // input w
|
||||||
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 + i < end_k){
|
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 + i < end_k){
|
||||||
// int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
|
|
||||||
int inOffsetTmp = layout == 0 ?
|
int inOffsetTmp = layout == 0 ?
|
||||||
curH * inChannelOffset + curW * param.c + curC:
|
curH * inChannelOffset + curW * param.c + curC:
|
||||||
curC * inChannelOffset + curH * param.w + curW;
|
curC * inChannelOffset + curH * param.w + curW;
|
||||||
|
|
@ -246,40 +181,9 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// sts
|
|
||||||
// for (int i = 0; i < 4; ++i)
|
|
||||||
// {
|
|
||||||
// smemweight[weight_sts_addr + i*132] = weight_ldg_reg[i];
|
|
||||||
// }
|
|
||||||
// for (int i = 0; i < 4; ++i)
|
|
||||||
// {
|
|
||||||
// smeminput[input_sts_addr + i * 128] = input_ldg_reg[i];
|
|
||||||
// }
|
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
// if(tx == 0 && bx == 0 && by == 0 && z == 0){
|
|
||||||
// for(int i=0; i < 128; ++i)
|
|
||||||
// printf("%.2f,", smeminput[i]);
|
|
||||||
// printf("\n");
|
|
||||||
// for(int i=128; i < 256; ++i)
|
|
||||||
// printf("%.2f,", smeminput[i]);
|
|
||||||
// printf("\n");
|
|
||||||
// }
|
|
||||||
|
|
||||||
// if(tx == 0 && bx == 0 && by == 0 && z == 0){
|
|
||||||
// printf("%u, %u, %u, %u \n", innerRowA, innerColA, rowStrideA, weight_sts_addr);
|
|
||||||
// for(int i=0; i < 16; ++i)
|
|
||||||
// printf("%f,", smemweight[i]);
|
|
||||||
// printf("\n");
|
|
||||||
// for(int i=0; i < 16; ++i)
|
|
||||||
// printf("%f,", param.weight[i*param.c*param.r*param.s]);
|
|
||||||
// printf("\n");
|
|
||||||
// }
|
|
||||||
|
|
||||||
// lds
|
// lds
|
||||||
// int input_lds_addr = (warp_id % 2) * 64 + mma_tid_x * 4;
|
|
||||||
const uint input_lds_addr = mma_tid_x * WM;
|
const uint input_lds_addr = mma_tid_x * WM;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx)
|
for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx)
|
||||||
|
|
@ -288,7 +192,6 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||||
input_frag[0][wSubRowIdx * TM + i] = smeminput[input_lds_addr + wSubRowIdx * WSUBM +
|
input_frag[0][wSubRowIdx * TM + i] = smeminput[input_lds_addr + wSubRowIdx * WSUBM +
|
||||||
threadRowInWarp * TM + i];
|
threadRowInWarp * TM + i];
|
||||||
|
|
||||||
// int weight_lds_addr = (warp_id / 2) * 32 + mma_tid_y * 4;
|
|
||||||
const uint weight_lds_addr = mma_tid_y * WN;
|
const uint weight_lds_addr = mma_tid_y * WN;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx)
|
for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx)
|
||||||
|
|
@ -297,95 +200,19 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||||
weight_frag[0][wSubColIdx * TN + i] = smemweight[weight_lds_addr + wSubColIdx * WSUBN +
|
weight_frag[0][wSubColIdx * TN + i] = smemweight[weight_lds_addr + wSubColIdx * WSUBN +
|
||||||
threadColInWarp * TN + i];
|
threadColInWarp * TN + i];
|
||||||
|
|
||||||
// #pragma unroll
|
for (int crs = start_k; crs < end_k; crs += BK) {
|
||||||
// for (int i = 0; i < 4; ++i)
|
|
||||||
// {
|
|
||||||
// weight_frag[0][i] = smemweight[weight_lds_addr + i];
|
|
||||||
// weight_frag[0][i + 4] = smemweight[weight_lds_addr + i + 16];
|
|
||||||
// }
|
|
||||||
// if(tx == 0 && bx == 0 && by == 0 && z == 0)
|
|
||||||
// {
|
|
||||||
// printf("weight_ldg_reg:%f,%f,%f,%f\n", weight_frag[0][0], weight_frag[0][1], weight_frag[0][2], weight_frag[0][3]);
|
|
||||||
// printf("weight_ldg_reg:%f,%f,%f,%f\n", weight_frag[0][4], weight_frag[0][5], weight_frag[0][6], weight_frag[0][7]);
|
|
||||||
// }
|
|
||||||
// #pragma unroll
|
|
||||||
// for (int i = 0; i < 4; ++i)
|
|
||||||
// {
|
|
||||||
// input_frag[0][i] = smeminput[input_lds_addr + i];
|
|
||||||
// input_frag[0][i + 4] = smeminput[input_lds_addr + i + 32];
|
|
||||||
// }
|
|
||||||
|
|
||||||
|
|
||||||
for (int crs = start_k; crs < end_k; crs += BK)
|
|
||||||
{
|
|
||||||
// ldg
|
|
||||||
// if (by * BN + tx / 2 < param.k && tx % 2 * 4 < param.c * param.r * param.s){
|
|
||||||
// float4 tmp = reinterpret_cast<float4 *>(¶m.weight[by * BN + tx / 2 * weightKOffset + tx % 2 * 4 + crs + 8])[0];
|
|
||||||
// weight_ldg_reg[0] = tmp.x;
|
|
||||||
// weight_ldg_reg[1] = tmp.y;
|
|
||||||
// weight_ldg_reg[2] = tmp.z;
|
|
||||||
// weight_ldg_reg[3] = tmp.w;
|
|
||||||
// } else {
|
|
||||||
// #pragma unroll
|
|
||||||
// for (int i = 0; i < 4; ++i)
|
|
||||||
// weight_ldg_reg[i] = 0.0;
|
|
||||||
// }
|
|
||||||
// curR = (crs + 8 + tx % 2 * 4) / (param.s * param.c); // channel offset
|
|
||||||
// curS = ((crs + 8 + tx % 2 * 4) % (param.s * param.c)) / param.c; // kernel r offset
|
|
||||||
// curC = ((crs + 8 + tx % 2 * 4) % (param.s * param.c)) % param.c; // kernel s offset
|
|
||||||
// curR = fastdiv(crs + 8 + (tx % 2) * 4, param.SC_fastdiv); // channel offset
|
|
||||||
// curS = fastdiv(fastmodulo(crs + 8 + (tx % 2) * 4, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
|
||||||
// curC = fastmodulo(fastmodulo(crs + 8 + (tx % 2) * 4, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
|
||||||
|
|
||||||
// int curH = posh_ori + curR; // input h
|
|
||||||
// int curW = posw_ori + curS; // input w
|
|
||||||
// if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h){
|
|
||||||
// int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
|
|
||||||
|
|
||||||
// // float4 tmp = reinterpret_cast<float4 *>(¶m.input[inOffset + inOffsetTmp])[0];
|
|
||||||
// // input_ldg_reg[0] = tmp.x;
|
|
||||||
// // input_ldg_reg[1] = tmp.y;
|
|
||||||
// // input_ldg_reg[2] = tmp.z;
|
|
||||||
// // input_ldg_reg[3] = tmp.w;
|
|
||||||
// reinterpret_cast<float4 *>(&input_ldg_reg[0])[0] = reinterpret_cast<float4 *>(¶m.input[inOffset + inOffsetTmp])[0]; } else {
|
|
||||||
// #pragma unroll
|
|
||||||
// for (int i = 0; i < 4; ++i)
|
|
||||||
// input_ldg_reg[i] = 0.0;
|
|
||||||
// }
|
|
||||||
|
|
||||||
int load_flag = write_flag ^ 1;
|
int load_flag = write_flag ^ 1;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int subcrs = 0; subcrs < BK - 1; ++subcrs)
|
for (int subcrs = 0; subcrs < BK - 1; ++subcrs)
|
||||||
{
|
{
|
||||||
// #pragma unroll
|
|
||||||
// for (int i = 0; i < 4; ++i)
|
|
||||||
// {
|
|
||||||
// weight_frag[(subcrs + 1) % 2][i] = smemweight[load_flag * (BN+4) * 8 + weight_lds_addr + (subcrs + 1) * (BN+4) + i];
|
|
||||||
// weight_frag[(subcrs + 1) % 2][i + 4] = smemweight[load_flag * (BN+4) * 8 + weight_lds_addr + (subcrs + 1) * (BN+4) + i + 16];
|
|
||||||
// }
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx)
|
for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx)
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint i = 0; i < TN; ++i)
|
for (uint i = 0; i < TN; ++i)
|
||||||
weight_frag[(subcrs + 1) % 2][wSubColIdx * TN + i] = smemweight[load_flag * (BN+PAD) * BK +
|
weight_frag[(subcrs + 1) % 2][wSubColIdx * TN + i] = smemweight[load_flag * (BN+PAD) * BK +
|
||||||
(subcrs + 1) * (BN+PAD) + weight_lds_addr + wSubColIdx * WSUBN + threadColInWarp * TN + i];
|
(subcrs + 1) * (BN+PAD) + weight_lds_addr + wSubColIdx * WSUBN + threadColInWarp * TN + i];
|
||||||
// float* base_ptr = smemweight + load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132;
|
|
||||||
|
|
||||||
// // first 4 values -> weight_frag[...][0..3]
|
|
||||||
// float4 v0 = *reinterpret_cast<const float4*>(base_ptr);
|
|
||||||
|
|
||||||
// // next 4 values (offset +16) -> weight_frag[...][4..7]
|
|
||||||
// float4 v1 = *reinterpret_cast<const float4*>(base_ptr + 16);
|
|
||||||
|
|
||||||
// // unpack into weight_frag
|
|
||||||
// *reinterpret_cast<float4*>(&weight_frag[(subcrs + 1) % 2][0]) = v0;
|
|
||||||
// *reinterpret_cast<float4*>(&weight_frag[(subcrs + 1) % 2][4]) = v1;
|
|
||||||
// #pragma unroll
|
|
||||||
// for (int i = 0; i < 4; ++i)
|
|
||||||
// {
|
|
||||||
// input_frag[(subcrs + 1) % 2][i] = smeminput[load_flag * BM * 8 + input_lds_addr + (subcrs + 1) * BM + i];
|
|
||||||
// input_frag[(subcrs + 1) % 2][i + 4] = smeminput[load_flag * BM * 8 + input_lds_addr + (subcrs + 1) * BM + i + 32];
|
|
||||||
// }
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx)
|
for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx)
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
|
|
@ -393,15 +220,6 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||||
input_frag[(subcrs + 1) % 2][wSubRowIdx * TM + i] = smeminput[load_flag * (BM+PAD) * BK +
|
input_frag[(subcrs + 1) % 2][wSubRowIdx * TM + i] = smeminput[load_flag * (BM+PAD) * BK +
|
||||||
(subcrs + 1) * (BM+PAD) + input_lds_addr + wSubRowIdx * WSUBM + threadRowInWarp * TM + i];
|
(subcrs + 1) * (BM+PAD) + input_lds_addr + wSubRowIdx * WSUBM + threadRowInWarp * TM + i];
|
||||||
|
|
||||||
// #pragma unroll
|
|
||||||
// for (int i = 0; i < 8; ++i)
|
|
||||||
// {
|
|
||||||
// #pragma unroll
|
|
||||||
// for (int j = 0; j < 8; ++j)
|
|
||||||
// {
|
|
||||||
// output_frag[i][j] += weight_frag[subcrs % 2][i] * input_frag[subcrs % 2][j];
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// execute warptile matmul
|
// execute warptile matmul
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) {
|
for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) {
|
||||||
|
|
@ -416,15 +234,6 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||||
(wSubColIdx * TN) + resIdxN] +=
|
(wSubColIdx * TN) + resIdxN] +=
|
||||||
input_frag[subcrs % 2][wSubRowIdx * TM + resIdxM] *
|
input_frag[subcrs % 2][wSubRowIdx * TM + resIdxM] *
|
||||||
ggml_cuda_cast<float>(weight_frag[subcrs % 2][wSubColIdx * TN + resIdxN]);
|
ggml_cuda_cast<float>(weight_frag[subcrs % 2][wSubColIdx * TN + resIdxN]);
|
||||||
// if(tx == 0 && bx == 0 && by == 0 && z == 0){
|
|
||||||
// printf("subcrs:%d, i:%d, j:%d, %f * %f = %f, acc = %f\n", subcrs, wSubRowIdx * TM + resIdxM, wSubColIdx * TN + resIdxN,
|
|
||||||
// input_frag[subcrs % 2][wSubRowIdx * TM + resIdxM],
|
|
||||||
// weight_frag[subcrs % 2][wSubColIdx * TN + resIdxN],
|
|
||||||
// input_frag[subcrs % 2][wSubRowIdx * TM + resIdxM] *
|
|
||||||
// weight_frag[subcrs % 2][wSubColIdx * TN + resIdxN],
|
|
||||||
// output_frag[(wSubRowIdx * TM + resIdxM) * (WNITER * TN) +
|
|
||||||
// (wSubColIdx * TN) + resIdxN]);
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -450,12 +259,12 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||||
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 3*(BN+PAD)] = val[3];
|
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 3*(BN+PAD)] = val[3];
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < 4; ++i)
|
for (int i = 0; i < 4; ++i)
|
||||||
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f;
|
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f;
|
||||||
}
|
}
|
||||||
}else{
|
}else{
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < 4; ++i){
|
for (int i = 0; i < 4; ++i){
|
||||||
if (by * BN + innerRowA + offset < param.k && innerColA * 4 + crs + BK + i < end_k){
|
if (by * BN + innerRowA + offset < param.k && innerColA * 4 + crs + BK + i < end_k){
|
||||||
// float4 tmp = reinterpret_cast<float4 *>(¶m.weight[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK + i])[0];
|
// float4 tmp = reinterpret_cast<float4 *>(¶m.weight[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK + i])[0];
|
||||||
|
|
@ -474,9 +283,6 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||||
const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q;
|
const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q;
|
||||||
int inOffset = n * param.c * param.h * param.w ;
|
int inOffset = n * param.c * param.h * param.w ;
|
||||||
if(vec_load){
|
if(vec_load){
|
||||||
// const uint curR = fastdiv(innerColA * 4 + crs + BK, param.SC_fastdiv); // channel offset
|
|
||||||
// const uint curS = fastdiv(fastmodulo(innerColA * 4 + crs + BK, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
|
||||||
// const uint curC = fastmodulo(fastmodulo(innerColA * 4 + crs + BK, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
|
||||||
const uint cur0 = fastdiv(innerColA * 4 + crs + BK,
|
const uint cur0 = fastdiv(innerColA * 4 + crs + BK,
|
||||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
|
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
|
||||||
const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK,
|
const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK,
|
||||||
|
|
@ -507,11 +313,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||||
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = 0.f;
|
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = 0.f;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < 4; ++i){
|
for (int i = 0; i < 4; ++i){
|
||||||
// const uint curR = fastdiv(innerColA * 4 + crs + BK + i, param.SC_fastdiv); // channel offset
|
|
||||||
// const uint curS = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
|
||||||
// const uint curC = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
|
||||||
const uint cur0 = fastdiv(innerColA * 4 + crs + BK + i,
|
const uint cur0 = fastdiv(innerColA * 4 + crs + BK + i,
|
||||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
|
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
|
||||||
const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i,
|
const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i,
|
||||||
|
|
@ -527,7 +330,6 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||||
const int curH = posh_ori + curR * param.d_h; // input h
|
const int curH = posh_ori + curR * param.d_h; // input h
|
||||||
const int curW = posw_ori + curS * param.d_w; // input w
|
const int curW = posw_ori + curS * param.d_w; // input w
|
||||||
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK + i < end_k){
|
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK + i < end_k){
|
||||||
// int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
|
|
||||||
int inOffsetTmp = layout == 0 ?
|
int inOffsetTmp = layout == 0 ?
|
||||||
curH * inChannelOffset + curW * param.c + curC:
|
curH * inChannelOffset + curW * param.c + curC:
|
||||||
curC * inChannelOffset + curH * param.w + curW;
|
curC * inChannelOffset + curH * param.w + curW;
|
||||||
|
|
@ -538,17 +340,10 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// sts
|
|
||||||
// for (int i = 0; i < 4; ++i)
|
|
||||||
// {
|
|
||||||
// smemweight[write_flag * (BN+4) * 8 + weight_sts_addr + i * (BN+4)] = weight_ldg_reg[i];
|
|
||||||
// }
|
|
||||||
// for (int i = 0; i < 4; ++i)
|
|
||||||
// {
|
|
||||||
// smeminput[write_flag * BM * 8 + input_sts_addr + i * BM] = input_ldg_reg[i];
|
|
||||||
// }
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
write_flag ^= 1;
|
write_flag ^= 1;
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx)
|
for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx)
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
|
|
@ -561,18 +356,6 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||||
for (uint i = 0; i < TN; ++i)
|
for (uint i = 0; i < TN; ++i)
|
||||||
weight_frag[0][wSubColIdx * TN + i] = smemweight[(load_flag ^ 1) * (BN+PAD) * BK +
|
weight_frag[0][wSubColIdx * TN + i] = smemweight[(load_flag ^ 1) * (BN+PAD) * BK +
|
||||||
weight_lds_addr + wSubColIdx * WSUBN + threadColInWarp * TN + i];
|
weight_lds_addr + wSubColIdx * WSUBN + threadColInWarp * TN + i];
|
||||||
// #pragma unroll
|
|
||||||
// for (int i = 0; i < 4; ++i)
|
|
||||||
// {
|
|
||||||
// weight_frag[0][i] = smemweight[(load_flag ^ 1) * (BN+4) * 8 + weight_lds_addr + i];
|
|
||||||
// weight_frag[0][i + 4] = smemweight[(load_flag ^ 1) * (BN+4) * 8 + weight_lds_addr + i + 16];
|
|
||||||
// }
|
|
||||||
// #pragma unroll
|
|
||||||
// for (int i = 0; i < 4; ++i)
|
|
||||||
// {
|
|
||||||
// input_frag[0][i] = smeminput[(load_flag ^ 1) * BM * 8 + input_lds_addr + i];
|
|
||||||
// input_frag[0][i + 4] = smeminput[(load_flag ^ 1) * BM * 8 + input_lds_addr + i + 32];
|
|
||||||
// }
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) {
|
for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
|
|
@ -590,100 +373,12 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// #pragma unroll
|
|
||||||
// for (int i = 0; i < 8; ++i)
|
|
||||||
// {
|
|
||||||
// #pragma unroll
|
|
||||||
// for (int j = 0; j < 8; ++j)
|
|
||||||
// {
|
|
||||||
// output_frag[i][j] += weight_frag[1][i] * input_frag[1][j];
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// if(tx == 59 && bx == 0 && by == 0 && z == 0){
|
|
||||||
// for (int i = 0; i < WMITER * TM * WNITER * TN; ++i){
|
|
||||||
// printf("%f,", output_frag[i]);
|
|
||||||
// if((i+1) % (WNITER * TN) == 0)
|
|
||||||
// printf("\n");
|
|
||||||
// }
|
|
||||||
// printf("\n");
|
|
||||||
// }
|
|
||||||
// if(tx == 59 && bx == 0 && by == 0 && z == 0){
|
|
||||||
// int cnt[3] = {0};
|
|
||||||
// float values[3] = {-1.f};
|
|
||||||
// for (int i = 0; i < WMITER * TM * WNITER * TN; ++i){
|
|
||||||
// for(int j = 0; j < 3; j++){
|
|
||||||
// if (output_frag[i] == values[j]){
|
|
||||||
// cnt[j]++;
|
|
||||||
// break;
|
|
||||||
// } else{
|
|
||||||
// if (cnt[j] == 0){
|
|
||||||
// values[j] = output_frag[i];
|
|
||||||
// cnt[j]++;
|
|
||||||
// break;
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// for(int j = 0; j < 3; j++){
|
|
||||||
// if(values[j] != -1.f)
|
|
||||||
// printf("value: %f, cnt: %d \n", values[j], cnt[j]);
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// reuse smem
|
// reuse smem
|
||||||
float *smemoutput = reinterpret_cast<float *>(smem);
|
float *smemoutput = reinterpret_cast<float *>(smem);
|
||||||
// float *smembias = reinterpret_cast<float *>(smem + 16 * 1024);
|
|
||||||
|
|
||||||
// bias ldg/sts
|
|
||||||
// if (tx < BN)
|
|
||||||
// {
|
|
||||||
// smembias[tx] = param.bias[by * BN + tx];
|
|
||||||
// }
|
|
||||||
|
|
||||||
// constexpr uint OUTMITER = (TM * TN * WNITER * WMITER * NUM_THREADS) / (2 * BK * (BM + BN)) / OUTNITER;
|
|
||||||
// const uint WMITER_TM_OUTMITER = WMITER * TM / OUTMITER;
|
|
||||||
// const uint WNITER_TN_OUTNITER = WNITER * TN / OUTNITER;
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// // uint32_t bias_lds_addr = warp_id / 2 * 32;
|
|
||||||
|
|
||||||
// #pragma unroll
|
|
||||||
// for (int i = 0; i < 2; ++i)
|
|
||||||
// {
|
|
||||||
// #pragma unroll
|
|
||||||
// for (int j = 0; j < 2; ++j)
|
|
||||||
// {
|
|
||||||
// __syncthreads();
|
|
||||||
|
|
||||||
// #pragma unroll
|
|
||||||
// for (int subi = 0; subi < 4; ++subi)
|
|
||||||
// {
|
|
||||||
// #pragma unroll
|
|
||||||
// for (int subj = 0; subj < 4; ++subj)
|
|
||||||
// {
|
|
||||||
// // output sts
|
|
||||||
// smemoutput[output_sts_addr + subi * 8 * 4 + subj] = output_frag[i * 4 + subi][j * 4 + subj];
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// __syncthreads();
|
|
||||||
|
|
||||||
// #pragma unroll
|
|
||||||
// for (int subk = 0; subk < 16; ++subk)
|
|
||||||
// {
|
|
||||||
// int outOffset = z * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + n_idx + j * 32;
|
|
||||||
// if ((m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow)
|
|
||||||
// param.output[outOffset] = smemoutput[output_lds_addr + subk * 32];
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
const uint output_lds_addr = warp_id * WSUBM * WSUBN + lane_id;
|
const uint output_lds_addr = warp_id * WSUBM * WSUBN + lane_id;
|
||||||
// const uint m_idx = by * BN + mma_tid_y * WN + threadColInWarp * WNITER_TN_OUTNITER;
|
|
||||||
// const uint n_idx = bx * BM + mma_tid_x * WM + threadRowInWarp * WMITER_TM_OUTMITER;
|
|
||||||
// const uint output_sts_addr = warp_id * WMITER_TM_OUTMITER * WNITER_TN_OUTNITER * WARPSIZE +
|
|
||||||
// (threadRowInWarp * (WSUBN / TN) + threadColInWarp) * WMITER_TM_OUTMITER * WNITER_TN_OUTNITER;
|
|
||||||
const uint output_sts_addr = mma_tid_x * BN / WN * TM * TN * WARPSIZE + mma_tid_y * TM * TN * WARPSIZE +
|
const uint output_sts_addr = mma_tid_x * BN / WN * TM * TN * WARPSIZE + mma_tid_y * TM * TN * WARPSIZE +
|
||||||
threadColInWarp * TN * WSUBM + threadRowInWarp * TM;
|
threadColInWarp * TN * WSUBM + threadRowInWarp * TM;
|
||||||
const uint m_idx = by * BN + mma_tid_y * WN;
|
const uint m_idx = by * BN + mma_tid_y * WN;
|
||||||
|
|
@ -716,9 +411,6 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||||
const int n = (ksplit > 0) ? gemm_i / PQ : z;
|
const int n = (ksplit > 0) ? gemm_i / PQ : z;
|
||||||
const int col = (ksplit > 0) ? gemm_i % PQ : gemm_i;
|
const int col = (ksplit > 0) ? gemm_i % PQ : gemm_i;
|
||||||
if (n < param.n && row < param.k && col < param.Oh * param.Ow){
|
if (n < param.n && row < param.k && col < param.Oh * param.Ow){
|
||||||
// int outOffset = z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + (n_idx + j * 32);
|
|
||||||
// if (n < param.n && (m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow)
|
|
||||||
// param.interm[outOffset] = smemoutput[output_lds_addr + subk * 32];
|
|
||||||
const uint outOffset = ksplit > 0 ?
|
const uint outOffset = ksplit > 0 ?
|
||||||
z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow +
|
z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow +
|
||||||
row * param.Oh * param.Ow + col :
|
row * param.Oh * param.Ow + col :
|
||||||
|
|
@ -736,8 +428,7 @@ template <unsigned int mma_tiles_per_warp_m, unsigned int mma_tiles_per_warp_k,
|
||||||
__device__ __forceinline__ void ldmatrix_a(
|
__device__ __forceinline__ void ldmatrix_a(
|
||||||
const half* src,
|
const half* src,
|
||||||
half (®)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]
|
half (®)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]
|
||||||
)
|
){
|
||||||
{
|
|
||||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
|
#if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
|
||||||
static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 4");
|
static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 4");
|
||||||
static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
|
static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
|
||||||
|
|
@ -748,7 +439,7 @@ __device__ __forceinline__ void ldmatrix_a(
|
||||||
swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2);
|
swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2);
|
||||||
uint32_t src_addr = cvta_to_shared_u32(src + swizzled_offset);
|
uint32_t src_addr = cvta_to_shared_u32(src + swizzled_offset);
|
||||||
constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes
|
constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes
|
||||||
|
|
||||||
// 0
|
// 0
|
||||||
asm volatile (
|
asm volatile (
|
||||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
||||||
|
|
@ -782,7 +473,7 @@ __device__ __forceinline__ void ldmatrix_a(
|
||||||
);
|
);
|
||||||
|
|
||||||
src_addr ^= 0b10000;
|
src_addr ^= 0b10000;
|
||||||
|
|
||||||
// 1
|
// 1
|
||||||
asm volatile (
|
asm volatile (
|
||||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
||||||
|
|
@ -814,7 +505,7 @@ __device__ __forceinline__ void ldmatrix_a(
|
||||||
: "=r"(reg_[6][1][0]), "=r"(reg_[6][1][1]), "=r"(reg_[7][1][0]), "=r"(reg_[7][1][1])
|
: "=r"(reg_[6][1][0]), "=r"(reg_[6][1][1]), "=r"(reg_[7][1][0]), "=r"(reg_[7][1][1])
|
||||||
: "r"(src_addr + 96 * smem_stride_)
|
: "r"(src_addr + 96 * smem_stride_)
|
||||||
);
|
);
|
||||||
|
|
||||||
src_addr ^= 0b110000;
|
src_addr ^= 0b110000;
|
||||||
|
|
||||||
// 2
|
// 2
|
||||||
|
|
@ -892,31 +583,19 @@ template <unsigned int mma_tiles_per_warp_k, unsigned int mma_tiles_per_warp_n,
|
||||||
__device__ __forceinline__ void ldmatrix_b(
|
__device__ __forceinline__ void ldmatrix_b(
|
||||||
const half* src,
|
const half* src,
|
||||||
half (®)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]
|
half (®)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]
|
||||||
)
|
){
|
||||||
{
|
|
||||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
|
#if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
|
||||||
|
|
||||||
static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
|
static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
|
||||||
static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8");
|
static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8");
|
||||||
|
|
||||||
uint32_t (®_) [4][8] = reinterpret_cast<uint32_t(&)[4][8]>(reg);
|
uint32_t (®_) [4][8] = reinterpret_cast<uint32_t(&)[4][8]>(reg);
|
||||||
// const unsigned int logical_offset = ((threadIdx.x % 8) * smem_stride) + (((threadIdx.x % 32) / 8) * 8);
|
|
||||||
// unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b11100000000) >> 5);
|
|
||||||
// uint32_t src_addr = cvta_to_shared_u32(src + swizzled_offset);
|
|
||||||
// constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes
|
|
||||||
unsigned int logical_offset = (threadIdx.x % 32) * smem_stride;
|
unsigned int logical_offset = (threadIdx.x % 32) * smem_stride;
|
||||||
unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4);
|
unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4);
|
||||||
swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2);
|
swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2);
|
||||||
uint32_t src_addr = cvta_to_shared_u32(src + swizzled_offset);
|
uint32_t src_addr = cvta_to_shared_u32(src + swizzled_offset);
|
||||||
constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes
|
constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes
|
||||||
|
|
||||||
|
|
||||||
// asm volatile (
|
|
||||||
// "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
|
|
||||||
// "{%0, %1, %2, %3}, [%4];"
|
|
||||||
// : "=r"(reg_[0][0]), "=r"(reg_[0][1]), "=r"(reg_[0][2]), "=r"(reg_[0][3])
|
|
||||||
// : "r"(src_addr)
|
|
||||||
// );
|
|
||||||
|
|
||||||
// 0
|
// 0
|
||||||
asm volatile (
|
asm volatile (
|
||||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
||||||
|
|
@ -927,18 +606,15 @@ __device__ __forceinline__ void ldmatrix_b(
|
||||||
|
|
||||||
|
|
||||||
asm volatile (
|
asm volatile (
|
||||||
// "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
|
|
||||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
||||||
"{%0, %1, %2, %3}, [%4];"
|
"{%0, %1, %2, %3}, [%4];"
|
||||||
: "=r"(reg_[0][4]), "=r"(reg_[0][5]), "=r"(reg_[0][6]), "=r"(reg_[0][7])
|
: "=r"(reg_[0][4]), "=r"(reg_[0][5]), "=r"(reg_[0][6]), "=r"(reg_[0][7])
|
||||||
// : "r"(src_addr ^ 0b1000000)
|
|
||||||
: "r"(src_addr + 32 * smem_stride_)
|
: "r"(src_addr + 32 * smem_stride_)
|
||||||
);
|
);
|
||||||
|
|
||||||
src_addr ^= 0b10000;
|
src_addr ^= 0b10000;
|
||||||
|
|
||||||
asm volatile (
|
asm volatile (
|
||||||
// "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
|
|
||||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
||||||
"{%0, %1, %2, %3}, [%4];"
|
"{%0, %1, %2, %3}, [%4];"
|
||||||
: "=r"(reg_[1][0]), "=r"(reg_[1][1]), "=r"(reg_[1][2]), "=r"(reg_[1][3])
|
: "=r"(reg_[1][0]), "=r"(reg_[1][1]), "=r"(reg_[1][2]), "=r"(reg_[1][3])
|
||||||
|
|
@ -946,19 +622,15 @@ __device__ __forceinline__ void ldmatrix_b(
|
||||||
);
|
);
|
||||||
|
|
||||||
asm volatile (
|
asm volatile (
|
||||||
// "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
|
|
||||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
||||||
"{%0, %1, %2, %3}, [%4];"
|
"{%0, %1, %2, %3}, [%4];"
|
||||||
: "=r"(reg_[1][4]), "=r"(reg_[1][5]), "=r"(reg_[1][6]), "=r"(reg_[1][7])
|
: "=r"(reg_[1][4]), "=r"(reg_[1][5]), "=r"(reg_[1][6]), "=r"(reg_[1][7])
|
||||||
// : "r"(src_addr ^ 0b1000000)
|
|
||||||
: "r"(src_addr + 32 * smem_stride_)
|
: "r"(src_addr + 32 * smem_stride_)
|
||||||
);
|
);
|
||||||
|
|
||||||
// src_addr += 8 * smem_stride_;
|
|
||||||
src_addr ^= 0b110000;
|
src_addr ^= 0b110000;
|
||||||
|
|
||||||
asm volatile (
|
asm volatile (
|
||||||
// "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
|
|
||||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
||||||
"{%0, %1, %2, %3}, [%4];"
|
"{%0, %1, %2, %3}, [%4];"
|
||||||
: "=r"(reg_[2][0]), "=r"(reg_[2][1]), "=r"(reg_[2][2]), "=r"(reg_[2][3])
|
: "=r"(reg_[2][0]), "=r"(reg_[2][1]), "=r"(reg_[2][2]), "=r"(reg_[2][3])
|
||||||
|
|
@ -966,18 +638,15 @@ __device__ __forceinline__ void ldmatrix_b(
|
||||||
);
|
);
|
||||||
|
|
||||||
asm volatile (
|
asm volatile (
|
||||||
// "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
|
|
||||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
||||||
"{%0, %1, %2, %3}, [%4];"
|
"{%0, %1, %2, %3}, [%4];"
|
||||||
: "=r"(reg_[2][4]), "=r"(reg_[2][5]), "=r"(reg_[2][6]), "=r"(reg_[2][7])
|
: "=r"(reg_[2][4]), "=r"(reg_[2][5]), "=r"(reg_[2][6]), "=r"(reg_[2][7])
|
||||||
// : "r"(src_addr ^ 0b1000000)
|
|
||||||
: "r"(src_addr + 32 * smem_stride_)
|
: "r"(src_addr + 32 * smem_stride_)
|
||||||
);
|
);
|
||||||
|
|
||||||
src_addr ^= 0b10000;
|
src_addr ^= 0b10000;
|
||||||
|
|
||||||
asm volatile (
|
asm volatile (
|
||||||
// "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
|
|
||||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
||||||
"{%0, %1, %2, %3}, [%4];"
|
"{%0, %1, %2, %3}, [%4];"
|
||||||
: "=r"(reg_[3][0]), "=r"(reg_[3][1]), "=r"(reg_[3][2]), "=r"(reg_[3][3])
|
: "=r"(reg_[3][0]), "=r"(reg_[3][1]), "=r"(reg_[3][2]), "=r"(reg_[3][3])
|
||||||
|
|
@ -985,11 +654,9 @@ __device__ __forceinline__ void ldmatrix_b(
|
||||||
);
|
);
|
||||||
|
|
||||||
asm volatile (
|
asm volatile (
|
||||||
// "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
|
|
||||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
||||||
"{%0, %1, %2, %3}, [%4];"
|
"{%0, %1, %2, %3}, [%4];"
|
||||||
: "=r"(reg_[3][4]), "=r"(reg_[3][5]), "=r"(reg_[3][6]), "=r"(reg_[3][7])
|
: "=r"(reg_[3][4]), "=r"(reg_[3][5]), "=r"(reg_[3][6]), "=r"(reg_[3][7])
|
||||||
// : "r"(src_addr ^ 0b1000000)
|
|
||||||
: "r"(src_addr + 32 * smem_stride_)
|
: "r"(src_addr + 32 * smem_stride_)
|
||||||
);
|
);
|
||||||
#else
|
#else
|
||||||
|
|
@ -1006,37 +673,27 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
||||||
half * __restrict__ output,
|
half * __restrict__ output,
|
||||||
const param_t param) {
|
const param_t param) {
|
||||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
|
#if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
|
||||||
constexpr unsigned int MMA_M = 16;
|
|
||||||
constexpr unsigned int MMA_N = 8;
|
|
||||||
|
|
||||||
// if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y ==0)
|
constexpr unsigned int MMA_M = 16;
|
||||||
// printf("conv2d_implicit_kernel launch BM:%d, BN:%d, BK:%d, WM:%d, WN:%d, WK:%d, NUM_THREADS:%d \n", BM, BN, BK, WM, WN, WK, NUM_THREADS);
|
constexpr unsigned int MMA_N = 8;
|
||||||
|
|
||||||
|
|
||||||
const unsigned int K = param.c * param.r * param.s;
|
const unsigned int K = param.c * param.r * param.s;
|
||||||
// const uint PQ = param.Oh * param.Ow;
|
|
||||||
const uint inChannelOffset = param.c * param.w;
|
const uint inChannelOffset = param.c * param.w;
|
||||||
const uint weightKOffset = param.c * param.r * param.s;
|
const uint weightKOffset = param.c * param.r * param.s;
|
||||||
|
|
||||||
// for convenience/readability in index calculations
|
|
||||||
// const unsigned int A_stride = K;
|
|
||||||
// const unsigned int B_stride = N;
|
|
||||||
// const unsigned int CD_stride = N;
|
|
||||||
|
|
||||||
// calculate how many bits of shared memory indices are going to be swizzled, and create masks
|
|
||||||
// constexpr unsigned int SWIZZLE_BITS_B = int_log2(BN / 8);
|
|
||||||
|
|
||||||
// loop bounds, constexpr where possible allows for loop unrolling
|
// loop bounds, constexpr where possible allows for loop unrolling
|
||||||
constexpr unsigned int mma_tiles_per_warp_k = 4;
|
constexpr unsigned int mma_tiles_per_warp_k = 4;
|
||||||
constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M;
|
constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M;
|
||||||
constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N;
|
constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N;
|
||||||
const unsigned int num_block_tiles_k = (K + (BK-1)) / BK;
|
const unsigned int num_block_tiles_k = (K + (BK-1)) / BK;
|
||||||
|
|
||||||
// calculate block/warp indices
|
// calculate block/warp indices
|
||||||
const unsigned int block_m = blockIdx.y;
|
const unsigned int block_m = blockIdx.y;
|
||||||
const unsigned int block_n = blockIdx.x;
|
const unsigned int block_n = blockIdx.x;
|
||||||
const unsigned int warp_m = threadIdx.y;
|
const unsigned int warp_m = threadIdx.y;
|
||||||
const unsigned int warp_n = threadIdx.x / 32;
|
const unsigned int warp_n = threadIdx.x / 32;
|
||||||
|
|
||||||
// double buffering
|
// double buffering
|
||||||
extern __shared__ half shmem[];
|
extern __shared__ half shmem[];
|
||||||
half* A_block_smem = shmem;
|
half* A_block_smem = shmem;
|
||||||
|
|
@ -1046,7 +703,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
||||||
// declare register storage
|
// declare register storage
|
||||||
// ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together
|
// ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together
|
||||||
uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2];
|
uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2];
|
||||||
// float acc_register_[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4];
|
|
||||||
uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2];
|
uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2];
|
||||||
uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n];
|
uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n];
|
||||||
|
|
||||||
|
|
@ -1056,10 +712,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
||||||
half (&B_register_) [mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] = reinterpret_cast<half(&)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]>(B_register);
|
half (&B_register_) [mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] = reinterpret_cast<half(&)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]>(B_register);
|
||||||
|
|
||||||
// accumulators start at 0
|
// accumulators start at 0
|
||||||
for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
|
for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++){
|
||||||
{
|
for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++){
|
||||||
for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++)
|
|
||||||
{
|
|
||||||
acc_register_[mma_m][mma_n][0] = 0;
|
acc_register_[mma_m][mma_n][0] = 0;
|
||||||
acc_register_[mma_m][mma_n][1] = 0;
|
acc_register_[mma_m][mma_n][1] = 0;
|
||||||
acc_register_[mma_m][mma_n][2] = 0;
|
acc_register_[mma_m][mma_n][2] = 0;
|
||||||
|
|
@ -1067,9 +721,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// these register arrays are used to cache values pre-fetched from global memory during the inner loop of the kernel
|
|
||||||
// the code is nicer if we hard code it for these tile dimensions and number of threads
|
|
||||||
// since we performing this copy with float4 pointers, for these tile dimensions it works out to be 8 float4s for A and 4 float4s for B
|
|
||||||
static_assert(BM == 256);
|
static_assert(BM == 256);
|
||||||
static_assert(BN == 256);
|
static_assert(BN == 256);
|
||||||
static_assert(BK == 32);
|
static_assert(BK == 32);
|
||||||
|
|
@ -1078,31 +729,19 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
||||||
float4 B_gmem_cache_reg[4];
|
float4 B_gmem_cache_reg[4];
|
||||||
|
|
||||||
// prefetch the first block tile of A,B into shared memory
|
// prefetch the first block tile of A,B into shared memory
|
||||||
// half* A_block_gmem = input + (block_m * BM * A_stride);
|
|
||||||
const half* A_block_gmem = input;
|
const half* A_block_gmem = input;
|
||||||
// const half* B_block_gmem = kernel + (block_n * weightKOffset);
|
|
||||||
const half* B_block_gmem = kernel + block_n * BN * weightKOffset;
|
const half* B_block_gmem = kernel + block_n * BN * weightKOffset;
|
||||||
tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, inChannelOffset, param);
|
tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, inChannelOffset, param);
|
||||||
tileMemcpySwizzleB<BN, NUM_THREADS>(B_block_gmem, B_block_smem, weightKOffset, param);
|
tileMemcpySwizzleB<BN, NUM_THREADS>(B_block_gmem, B_block_smem, weightKOffset, param);
|
||||||
|
|
||||||
// construct const pointers to warp tiles for use inside the inner loop
|
|
||||||
// if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x ==0 && blockIdx.y ==0){
|
|
||||||
// for(int i = 0; i < 32; ++i)
|
|
||||||
// printf("%.2f,", __half2float(A_block_smem[i]));
|
|
||||||
// printf("\n");
|
|
||||||
// }
|
|
||||||
|
|
||||||
int offset_direction = 1;
|
int offset_direction = 1;
|
||||||
|
|
||||||
for (unsigned int block_k = 1; block_k <= num_block_tiles_k; block_k++)
|
for (unsigned int block_k = 1; block_k <= num_block_tiles_k; block_k++){
|
||||||
{
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
if (block_k != num_block_tiles_k)
|
if (block_k != num_block_tiles_k){
|
||||||
{
|
|
||||||
// half* A_block_gmem = A + (block_m * BM * A_stride) + (block_k * BK);
|
|
||||||
const half* A_block_gmem = input;
|
const half* A_block_gmem = input;
|
||||||
// const half* B_block_gmem = kernel + (block_n * weightKOffset);
|
|
||||||
const half* B_block_gmem = kernel + (block_n * BN * weightKOffset);
|
const half* B_block_gmem = kernel + (block_n * BN * weightKOffset);
|
||||||
tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, block_k * BK, inChannelOffset, param);
|
tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, block_k * BK, inChannelOffset, param);
|
||||||
tileMemcpyLoadB<BN, BK, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param);
|
tileMemcpyLoadB<BN, BK, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param);
|
||||||
|
|
@ -1114,18 +753,14 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
||||||
ldmatrix_b<mma_tiles_per_warp_k, mma_tiles_per_warp_n, BK>(B_warp_tile, B_register_);
|
ldmatrix_b<mma_tiles_per_warp_k, mma_tiles_per_warp_n, BK>(B_warp_tile, B_register_);
|
||||||
|
|
||||||
// outer product between mma tiles
|
// outer product between mma tiles
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (unsigned int mma_k = 0; mma_k < mma_tiles_per_warp_k; mma_k++)
|
for (unsigned int mma_k = 0; mma_k < mma_tiles_per_warp_k; mma_k++){
|
||||||
{
|
#pragma unroll
|
||||||
#pragma unroll
|
for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++){
|
||||||
for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++)
|
#pragma unroll
|
||||||
{
|
for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++){
|
||||||
#pragma unroll
|
|
||||||
for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
|
|
||||||
{
|
|
||||||
asm volatile (
|
asm volatile (
|
||||||
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
|
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
|
||||||
// "mma.sync.aligned.m16n8k8.row.row.f16.f16.f16.f16 "
|
|
||||||
"{%0, %1}, "
|
"{%0, %1}, "
|
||||||
"{%2, %3}, "
|
"{%2, %3}, "
|
||||||
"{%4}, "
|
"{%4}, "
|
||||||
|
|
@ -1135,53 +770,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
||||||
"r"(B_register[mma_k][mma_n])
|
"r"(B_register[mma_k][mma_n])
|
||||||
"r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1])
|
"r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1])
|
||||||
);
|
);
|
||||||
// asm volatile (
|
|
||||||
// "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
|
|
||||||
// "{%0, %1, %2, %3},"
|
|
||||||
// "{%4, %5},"
|
|
||||||
// "{%6},"
|
|
||||||
// "{%7, %8, %9, %10};\n"
|
|
||||||
// : "=f"(acc_register_[mma_m][mma_n][0]), "=f"(acc_register_[mma_m][mma_n][1]),
|
|
||||||
// "=f"(acc_register_[mma_m][mma_n][2]), "=f"(acc_register_[mma_m][mma_n][3])
|
|
||||||
// : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]),
|
|
||||||
// "r"(B_register[mma_k][mma_n]),
|
|
||||||
// "f"(acc_register_[mma_m][mma_n][0]), "f"(acc_register_[mma_m][mma_n][1]),
|
|
||||||
// "f"(acc_register_[mma_m][mma_n][2]), "f"(acc_register_[mma_m][mma_n][3])
|
|
||||||
// );
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// if(threadIdx.x == 12 && threadIdx.y ==0 && blockIdx.x ==0 && blockIdx.y ==0){
|
|
||||||
// printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(acc_register_[0][0][0]), __half2float(acc_register_[0][0][1]),
|
|
||||||
// __half2float(acc_register_[0][0][2]), __half2float(acc_register_[0][0][3]));
|
|
||||||
// printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, acc_register_[0][0][0], acc_register_[0][0][1],
|
|
||||||
// acc_register_[0][0][2], acc_register_[0][0][3]);
|
|
||||||
// printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(A_register_[0][mma_k][0]), __half2float(A_register_[0][mma_k][1]),
|
|
||||||
// __half2float(A_register_[0][mma_k][2]), __half2float(A_register_[0][mma_k][3]));
|
|
||||||
// printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(B_register_[mma_k][0][0]), __half2float(B_register_[mma_k][0][1]),
|
|
||||||
// __half2float(B_register_[mma_k][0][2]), __half2float(B_register_[mma_k][0][3]));
|
|
||||||
// printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, acc_register_[1][0][0], acc_register_[1][0][1],
|
|
||||||
// acc_register_[1][0][2], acc_register_[1][0][3]);
|
|
||||||
// printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(A_register_[1][mma_k][0]), __half2float(A_register_[1][mma_k][1]),
|
|
||||||
// __half2float(A_register_[1][mma_k][2]), __half2float(A_register_[1][mma_k][3]));
|
|
||||||
// printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, acc_register_[3][0][0], acc_register_[3][0][1],
|
|
||||||
// acc_register_[3][0][2], acc_register_[3][0][3]);
|
|
||||||
// printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(A_register_[3][mma_k][0]), __half2float(A_register_[3][mma_k][1]),
|
|
||||||
// __half2float(A_register_[3][mma_k][2]), __half2float(A_register_[3][mma_k][3]));
|
|
||||||
// printf(" %d, %d: %f, %f, \n", block_k, mma_k, __half2float(B_register_[mma_k][0][0]), __half2float(B_register_[mma_k][0][1]));
|
|
||||||
// }
|
|
||||||
// if(threadIdx.x < 4 && threadIdx.y ==0 && blockIdx.x ==0 && blockIdx.y ==0){
|
|
||||||
// printf("A %d, %d, %d: %f, %f \n", block_k, mma_k, threadIdx.x, __half2float(A_register_[3][mma_k][0]), __half2float(A_register_[3][mma_k][1]));
|
|
||||||
// printf("B %d, %d, %d: %f, %f \n", block_k, mma_k, threadIdx.x, __half2float(B_register_[mma_k][0][0]), __half2float(B_register_[mma_k][0][1]));
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
// if(threadIdx.x == 0 && threadIdx.y ==0 && blockIdx.x ==0 && blockIdx.y ==0){
|
|
||||||
// printf(" %d: %f, %f, %f, %f \n", block_k, __half2float(acc_register_[3][0][0]), __half2float(acc_register_[3][0][1]),
|
|
||||||
// __half2float(acc_register_[3][0][2]), __half2float(acc_register_[3][0][3]));
|
|
||||||
// printf(" %d: %f, %f, %f, %f \n", block_k, __half2float(A_register_[3][0][0]), __half2float(A_register_[3][0][1]),
|
|
||||||
// __half2float(A_register_[3][0][2]), __half2float(A_register_[3][0][3]));
|
|
||||||
// printf(" %d: %f, %f, %f, %f \n", block_k, __half2float(B_register_[3][0][0]), __half2float(B_register_[3][0][1]),
|
|
||||||
// __half2float(B_register_[3][0][2]), __half2float(B_register_[3][0][3]));
|
|
||||||
// }
|
|
||||||
|
|
||||||
|
|
||||||
if (block_k != num_block_tiles_k)
|
if (block_k != num_block_tiles_k)
|
||||||
|
|
@ -1196,8 +787,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// reuse smem
|
// reuse smem
|
||||||
half *smemoutput = shmem;
|
half *smemoutput = shmem;
|
||||||
const uint lane_id = threadIdx.x % WARPSIZE;
|
const uint lane_id = threadIdx.x % WARPSIZE;
|
||||||
|
|
@ -1217,7 +806,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
||||||
{
|
{
|
||||||
for (unsigned int mma_n = i * mma_tiles_per_warp_n/2; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++)
|
for (unsigned int mma_n = i * mma_tiles_per_warp_n/2; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++)
|
||||||
{
|
{
|
||||||
// output sts
|
|
||||||
uint32_t (®_)[2] = reinterpret_cast<uint32_t(&)[2]>(acc_register_[mma_m][mma_n]);
|
uint32_t (®_)[2] = reinterpret_cast<uint32_t(&)[2]>(acc_register_[mma_m][mma_n]);
|
||||||
uint idx = output_sts_addr +
|
uint idx = output_sts_addr +
|
||||||
mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N;
|
mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N;
|
||||||
|
|
@ -1229,20 +817,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
// if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x ==0 && blockIdx.y ==0){
|
|
||||||
// for(int ii = 0; ii < 128; ++ii)
|
|
||||||
// printf("%.2f,", __half2float(smemoutput[ii]));
|
|
||||||
// printf("\n");
|
|
||||||
// for(int ii = 128; ii < 256; ++ii)
|
|
||||||
// printf("%.2f,", __half2float(smemoutput[ii]));
|
|
||||||
// printf("\n");
|
|
||||||
// for(int ii = 0; ii < 128; ++ii)
|
|
||||||
// printf("%.2f,", __half2float(smemoutput[ii*128]));
|
|
||||||
// printf("\n");
|
|
||||||
// for(int ii = 128; ii < 256; ++ii)
|
|
||||||
// printf("%.2f,", __half2float(smemoutput[ii*128]));
|
|
||||||
// printf("\n");
|
|
||||||
// }
|
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int subk = 0; subk < WN / 2; ++subk){
|
for (int subk = 0; subk < WN / 2; ++subk){
|
||||||
|
|
@ -1252,23 +826,14 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
||||||
const int n = fastdiv(gemm_i, param.OHOW_fastdiv);
|
const int n = fastdiv(gemm_i, param.OHOW_fastdiv);
|
||||||
const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
|
const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
|
||||||
if(n < param.n && row < param.k && col < param.Oh * param.Ow){
|
if(n < param.n && row < param.k && col < param.Oh * param.Ow){
|
||||||
// int outOffset = z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + (n_idx + j * 32);
|
|
||||||
// if (n < param.n && (m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow)
|
|
||||||
// param.interm[outOffset] = smemoutput[output_lds_addr + subk * 32];
|
|
||||||
const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
|
const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
|
||||||
uint idx = output_lds_addr + subk + j*32*BN/2;
|
uint idx = output_lds_addr + subk + j*32*BN/2;
|
||||||
idx = idx ^ ((idx & 0b1110000000) >> 4);
|
idx = idx ^ ((idx & 0b1110000000) >> 4);
|
||||||
// output[outOffset] = smemoutput[output_lds_addr + subk + j*32*BN/2];
|
|
||||||
output[outOffset] = smemoutput[idx];
|
output[outOffset] = smemoutput[idx];
|
||||||
// if(outOffset == 32){
|
|
||||||
// printf("(%u, %u, %u, %u), output[%d,%d,%d]=%f \n", threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y,
|
|
||||||
// n, row, col, __half2float(output[outOffset]));
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#else
|
#else
|
||||||
GGML_UNUSED(input);
|
GGML_UNUSED(input);
|
||||||
GGML_UNUSED(kernel);
|
GGML_UNUSED(kernel);
|
||||||
|
|
@ -1279,7 +844,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#define NUM_VARIANTS 6
|
#define NUM_VARIANTS 4
|
||||||
|
|
||||||
/*
|
/*
|
||||||
conv_shapes[][0]: ne_input=[384,512,256,1],ne_kernel=[3,3,256,256]
|
conv_shapes[][0]: ne_input=[384,512,256,1],ne_kernel=[3,3,256,256]
|
||||||
|
|
@ -1313,12 +878,10 @@ static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D,
|
||||||
int blockx = ((P.Oh * P.Ow + BM - 1) / BM); // blockx number
|
int blockx = ((P.Oh * P.Ow + BM - 1) / BM); // blockx number
|
||||||
int blocky = (P.k + BN-1) / BN; // blocky number
|
int blocky = (P.k + BN-1) / BN; // blocky number
|
||||||
int blockz = P.n; // blockz number
|
int blockz = P.n; // blockz number
|
||||||
// int threadx = NUM; // threadx number per block
|
|
||||||
int thready = 1; // thready number per block
|
int thready = 1; // thready number per block
|
||||||
int threadz = 1; // threadz number per block
|
int threadz = 1; // threadz number per block
|
||||||
dim3 thblock(NUM_THREADS, thready, threadz);
|
dim3 thblock(NUM_THREADS, thready, threadz);
|
||||||
dim3 grid(blockx, blocky, blockz);
|
dim3 grid(blockx, blocky, blockz);
|
||||||
// int smem_size = 24 * 1024;
|
|
||||||
if(P.c % 4 == 0){
|
if(P.c % 4 == 0){
|
||||||
if(P.layout == 0)
|
if(P.layout == 0)
|
||||||
conv2d_implicit_kernel<T, BM, BN, BK, WM, WN,
|
conv2d_implicit_kernel<T, BM, BN, BK, WM, WN,
|
||||||
|
|
@ -1337,9 +900,8 @@ static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D,
|
||||||
}
|
}
|
||||||
|
|
||||||
static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) {
|
static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) {
|
||||||
|
|
||||||
if (GGML_CUDA_CC_IS_NVIDIA(cc) && ampere_mma_available(cc) && P.layout == 0 && P.c % 8 == 0) {
|
if (GGML_CUDA_CC_IS_NVIDIA(cc) && ampere_mma_available(cc) && P.layout == 0 && P.c % 8 == 0) {
|
||||||
// #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
|
||||||
// printf("tensor core path called\n");
|
|
||||||
constexpr unsigned int BM_dim = 256;
|
constexpr unsigned int BM_dim = 256;
|
||||||
constexpr unsigned int BN_dim = 256;
|
constexpr unsigned int BN_dim = 256;
|
||||||
constexpr unsigned int BK_dim = 32;
|
constexpr unsigned int BK_dim = 32;
|
||||||
|
|
@ -1378,10 +940,6 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
|
||||||
<<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_D, Y_H.get(), P);
|
<<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_D, Y_H.get(), P);
|
||||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||||
to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st);
|
to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st);
|
||||||
// #else
|
|
||||||
// printf("non tensor path called\n");
|
|
||||||
// conv2d_implicit_cuda<half, 1>(X_D, K_D, Y_D, P, st);
|
|
||||||
// #endif
|
|
||||||
} else{
|
} else{
|
||||||
conv2d_implicit_cuda<half, 1>(X_D, K_D, Y_D, P, st);
|
conv2d_implicit_cuda<half, 1>(X_D, K_D, Y_D, P, st);
|
||||||
}
|
}
|
||||||
|
|
@ -1422,13 +980,6 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||||
// No cwhn
|
// No cwhn
|
||||||
GGML_ASSERT(p[7] == false);
|
GGML_ASSERT(p[7] == false);
|
||||||
|
|
||||||
// const int IW = input->ne[0]; // input_w
|
|
||||||
// const int IH = input->ne[1]; // input_h
|
|
||||||
// const int OW = dst->ne[0]; // output_w
|
|
||||||
// const int OH = dst->ne[1]; // output_h
|
|
||||||
// const int KW = kernel->ne[0]; // kernel_w
|
|
||||||
// const int KH = kernel->ne[1]; // kernel_h
|
|
||||||
// const int IC = input->ne[2]; // input_channels
|
|
||||||
const int IW = input->ne[LT == 0 ? 1 : 0]; // input_w
|
const int IW = input->ne[LT == 0 ? 1 : 0]; // input_w
|
||||||
const int IH = input->ne[LT == 0 ? 2 : 1]; // input_h
|
const int IH = input->ne[LT == 0 ? 2 : 1]; // input_h
|
||||||
const int OW = dst->ne[0]; // output_w
|
const int OW = dst->ne[0]; // output_w
|
||||||
|
|
|
||||||
|
|
@ -35,45 +35,8 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
|
||||||
half* dst,
|
half* dst,
|
||||||
const unsigned int src_stride,
|
const unsigned int src_stride,
|
||||||
param_t param
|
param_t param
|
||||||
)
|
){
|
||||||
{
|
|
||||||
#if __CUDA_ARCH__ >= GGML_CUDA_TURING
|
#if __CUDA_ARCH__ >= GGML_CUDA_TURING
|
||||||
// constexpr unsigned int SWIZZLE_MASK = 0b111 << SWIZZLE_BITS;
|
|
||||||
|
|
||||||
// // reinterpret input/output as float4
|
|
||||||
// float4* src_float4 = reinterpret_cast<float4*>(src);
|
|
||||||
// float4* dst_float4 = reinterpret_cast<float4*>(dst);
|
|
||||||
// const unsigned int src_stride_vectorized = src_stride / 8;
|
|
||||||
|
|
||||||
// // # of threads is multiple of # of columns in the tile
|
|
||||||
// constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
|
|
||||||
// static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
|
|
||||||
|
|
||||||
// // flatten out 2d grid of threads into in order of increasing threadIdx.x
|
|
||||||
// const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;
|
|
||||||
|
|
||||||
// // assign each thread a row/column in the tile, calculate how many iterations we need
|
|
||||||
// // to cover the whole tile
|
|
||||||
// constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
|
|
||||||
// constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
|
|
||||||
// unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
|
|
||||||
// const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
|
|
||||||
|
|
||||||
// #pragma unroll
|
|
||||||
// for (unsigned int i = 0; i < NUM_ITERS; i++)
|
|
||||||
// {
|
|
||||||
// // apply swizzle to the dst index
|
|
||||||
// const unsigned int src_index = thread_row * src_stride_vectorized + thread_col;
|
|
||||||
// unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
|
|
||||||
// dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK) >> SWIZZLE_BITS);
|
|
||||||
// if (thread_col * 8 < param.k && start_k + innerColA * 4 < end_k){
|
|
||||||
// float4 tmp = reinterpret_cast<const float4 *>(&src[thread_row * src_stride_vectorized + thread_col*8)[0];
|
|
||||||
// dst_float4[dst_index] = src_float4[src_index];
|
|
||||||
// }else{ // read 4 halves
|
|
||||||
// dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
|
|
||||||
// }
|
|
||||||
// thread_row += ROW_STEP;
|
|
||||||
// }
|
|
||||||
|
|
||||||
constexpr unsigned int SWIZZLE_MASK_1 = 0b10000;
|
constexpr unsigned int SWIZZLE_MASK_1 = 0b10000;
|
||||||
constexpr unsigned int SWIZZLE_BITS_1 = 4;
|
constexpr unsigned int SWIZZLE_BITS_1 = 4;
|
||||||
|
|
@ -81,10 +44,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
|
||||||
constexpr unsigned int SWIZZLE_BITS_2 = 2;
|
constexpr unsigned int SWIZZLE_BITS_2 = 2;
|
||||||
constexpr unsigned int TILE_COLS = 32;
|
constexpr unsigned int TILE_COLS = 32;
|
||||||
|
|
||||||
// reinterpret input/output as float4
|
|
||||||
// float4* src_float4 = reinterpret_cast<float4*>(src);
|
|
||||||
float4* dst_float4 = reinterpret_cast<float4*>(dst);
|
float4* dst_float4 = reinterpret_cast<float4*>(dst);
|
||||||
// const unsigned int src_stride_vectorized = src_stride / 8;
|
|
||||||
|
|
||||||
// # of threads is multiple of # of columns in the tile
|
// # of threads is multiple of # of columns in the tile
|
||||||
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
|
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
|
||||||
|
|
@ -98,14 +58,12 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
|
||||||
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
|
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
|
||||||
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
|
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
|
||||||
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
|
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
|
||||||
// TODO: next block_k loop
|
|
||||||
const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset
|
const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset
|
||||||
const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
||||||
const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); //
|
const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); //
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (unsigned int i = 0; i < NUM_ITERS; i++)
|
for (unsigned int i = 0; i < NUM_ITERS; i++){
|
||||||
{
|
|
||||||
// apply swizzle to the dst index
|
// apply swizzle to the dst index
|
||||||
const unsigned int src_index = thread_row * src_stride + thread_col * 8;
|
const unsigned int src_index = thread_row * src_stride + thread_col * 8;
|
||||||
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
|
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
|
||||||
|
|
@ -140,16 +98,14 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
#if __CUDA_ARCH__ >= GGML_CUDA_TURING
|
#if __CUDA_ARCH__ >= GGML_CUDA_TURING
|
||||||
|
|
||||||
constexpr unsigned int SWIZZLE_MASK_1 = 0b10000;
|
constexpr unsigned int SWIZZLE_MASK_1 = 0b10000;
|
||||||
constexpr unsigned int SWIZZLE_BITS_1 = 4;
|
constexpr unsigned int SWIZZLE_BITS_1 = 4;
|
||||||
constexpr unsigned int SWIZZLE_MASK_2 = 0b1100;
|
constexpr unsigned int SWIZZLE_MASK_2 = 0b1100;
|
||||||
constexpr unsigned int SWIZZLE_BITS_2 = 2;
|
constexpr unsigned int SWIZZLE_BITS_2 = 2;
|
||||||
constexpr unsigned int TILE_COLS = 32;
|
constexpr unsigned int TILE_COLS = 32;
|
||||||
|
|
||||||
// reinterpret input/output as float4
|
|
||||||
// float4* src_float4 = reinterpret_cast<float4*>(src);
|
|
||||||
float4* dst_float4 = reinterpret_cast<float4*>(dst);
|
float4* dst_float4 = reinterpret_cast<float4*>(dst);
|
||||||
// const unsigned int src_stride_vectorized = src_stride / 8;
|
|
||||||
|
|
||||||
// # of threads is multiple of # of columns in the tile
|
// # of threads is multiple of # of columns in the tile
|
||||||
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
|
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
|
||||||
|
|
@ -166,16 +122,13 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
|
||||||
|
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (unsigned int i = 0; i < NUM_ITERS; i++)
|
for (unsigned int i = 0; i < NUM_ITERS; i++){
|
||||||
{
|
|
||||||
// unsigned int gemm_i = blockDim.y * TILE_ROWS + thread_row;
|
|
||||||
unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
|
unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
|
||||||
unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv);
|
unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv);
|
||||||
unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
|
unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
|
||||||
int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p;
|
int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p;
|
||||||
int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q;
|
int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q;
|
||||||
unsigned int inOffset = n * param.c * param.h * param.w;
|
unsigned int inOffset = n * param.c * param.h * param.w;
|
||||||
// TODO: next block_k loop
|
|
||||||
const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset
|
const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset
|
||||||
const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
||||||
const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
||||||
|
|
@ -187,7 +140,6 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
|
||||||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
|
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
|
||||||
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
|
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
|
||||||
curR < param.r && curS < param.s && curC < param.c){
|
curR < param.r && curS < param.s && curC < param.c){
|
||||||
// const unsigned int src_index = thread_row * src_stride_vectorized + thread_col;
|
|
||||||
const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
|
const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
|
||||||
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
|
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
|
||||||
} else{
|
} else{
|
||||||
|
|
@ -201,7 +153,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
|
||||||
GGML_UNUSED(inChannelOffset);
|
GGML_UNUSED(inChannelOffset);
|
||||||
GGML_UNUSED(param);
|
GGML_UNUSED(param);
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template<unsigned int TILE_ROWS,
|
template<unsigned int TILE_ROWS,
|
||||||
|
|
@ -215,17 +167,13 @@ __device__ __forceinline__ void tileMemcpyLoadA(
|
||||||
const unsigned int block_k,
|
const unsigned int block_k,
|
||||||
const unsigned int inChannelOffset,
|
const unsigned int inChannelOffset,
|
||||||
param_t param
|
param_t param
|
||||||
)
|
){
|
||||||
{
|
|
||||||
#if __CUDA_ARCH__ >= GGML_CUDA_TURING
|
#if __CUDA_ARCH__ >= GGML_CUDA_TURING
|
||||||
// reinterpret input/output as float4
|
|
||||||
// const float4* src_float4 = reinterpret_cast<const float4*>(src);
|
|
||||||
// const unsigned int src_stride_vectorized = src_stride / 8;
|
|
||||||
|
|
||||||
// # of threads is multiple of # of columns in the tile
|
// # of threads is multiple of # of columns in the tile
|
||||||
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
|
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
|
||||||
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
|
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
|
||||||
|
|
||||||
// flatten out 2d grid of threads into in order of increasing threadIdx.x
|
// flatten out 2d grid of threads into in order of increasing threadIdx.x
|
||||||
const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;
|
const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
|
|
@ -240,19 +188,13 @@ __device__ __forceinline__ void tileMemcpyLoadA(
|
||||||
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
|
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (unsigned int i = 0; i < NUM_ITERS; i++)
|
for (unsigned int i = 0; i < NUM_ITERS; i++){
|
||||||
{
|
|
||||||
// const unsigned int src_index = thread_row * src_stride_vectorized + thread_col;
|
|
||||||
// dst_reg[i] = src_float4[src_index];
|
|
||||||
// thread_row += ROW_STEP;
|
|
||||||
// unsigned int gemm_i = blockDim.y * TILE_ROWS + thread_row;
|
|
||||||
unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
|
unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
|
||||||
unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv);
|
unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv);
|
||||||
unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
|
unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
|
||||||
int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p;
|
int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p;
|
||||||
int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q;
|
int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q;
|
||||||
unsigned int inOffset = n * param.c * param.h * param.w;
|
unsigned int inOffset = n * param.c * param.h * param.w;
|
||||||
// TODO: next block_k loop
|
|
||||||
const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset
|
const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset
|
||||||
const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
||||||
const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
||||||
|
|
@ -260,7 +202,6 @@ __device__ __forceinline__ void tileMemcpyLoadA(
|
||||||
int curW = posw_ori + curS * param.d_w; // input w
|
int curW = posw_ori + curS * param.d_w; // input w
|
||||||
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
|
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
|
||||||
curR < param.r && curS < param.s && curC < param.c){
|
curR < param.r && curS < param.s && curC < param.c){
|
||||||
// const unsigned int src_index = thread_row * src_stride_vectorized + thread_col;
|
|
||||||
const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
|
const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
|
||||||
dst_reg[i] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
|
dst_reg[i] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
|
||||||
} else{
|
} else{
|
||||||
|
|
@ -289,17 +230,13 @@ __device__ __forceinline__ void tileMemcpyLoadB(
|
||||||
const unsigned int block_k,
|
const unsigned int block_k,
|
||||||
const unsigned int src_stride,
|
const unsigned int src_stride,
|
||||||
param_t param
|
param_t param
|
||||||
)
|
){
|
||||||
{
|
|
||||||
#if __CUDA_ARCH__ >= GGML_CUDA_TURING
|
#if __CUDA_ARCH__ >= GGML_CUDA_TURING
|
||||||
// reinterpret input/output as float4
|
|
||||||
// const float4* src_float4 = reinterpret_cast<const float4*>(src);
|
|
||||||
// const unsigned int src_stride_vectorized = src_stride / 8;
|
|
||||||
|
|
||||||
// # of threads is multiple of # of columns in the tile
|
// # of threads is multiple of # of columns in the tile
|
||||||
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
|
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
|
||||||
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
|
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
|
||||||
|
|
||||||
// flatten out 2d grid of threads into in order of increasing threadIdx.x
|
// flatten out 2d grid of threads into in order of increasing threadIdx.x
|
||||||
const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;
|
const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
|
|
@ -318,11 +255,7 @@ __device__ __forceinline__ void tileMemcpyLoadB(
|
||||||
const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); //
|
const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); //
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (unsigned int i = 0; i < NUM_ITERS; i++)
|
for (unsigned int i = 0; i < NUM_ITERS; i++){
|
||||||
{
|
|
||||||
// const unsigned int src_index = thread_row * src_stride_vectorized + thread_col;
|
|
||||||
// dst_reg[i] = src_float4[src_index];
|
|
||||||
// thread_row += ROW_STEP;
|
|
||||||
const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8;
|
const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8;
|
||||||
if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){
|
if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){
|
||||||
dst_reg[i] = reinterpret_cast<const float4 *>(&src[src_index])[0];
|
dst_reg[i] = reinterpret_cast<const float4 *>(&src[src_index])[0];
|
||||||
|
|
@ -338,7 +271,7 @@ __device__ __forceinline__ void tileMemcpyLoadB(
|
||||||
GGML_UNUSED(src_stride);
|
GGML_UNUSED(src_stride);
|
||||||
GGML_UNUSED(param);
|
GGML_UNUSED(param);
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -354,6 +287,7 @@ __device__ __forceinline__ void tileMemcpySwizzleStore(
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
#if __CUDA_ARCH__ >= GGML_CUDA_TURING
|
#if __CUDA_ARCH__ >= GGML_CUDA_TURING
|
||||||
|
|
||||||
constexpr unsigned int SWIZZLE_MASK_1 = 0b10000;
|
constexpr unsigned int SWIZZLE_MASK_1 = 0b10000;
|
||||||
constexpr unsigned int SWIZZLE_BITS_1 = 4;
|
constexpr unsigned int SWIZZLE_BITS_1 = 4;
|
||||||
constexpr unsigned int SWIZZLE_MASK_2 = 0b1100;
|
constexpr unsigned int SWIZZLE_MASK_2 = 0b1100;
|
||||||
|
|
@ -392,9 +326,9 @@ __device__ __forceinline__ void tileMemcpySwizzleStore(
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
GGML_UNUSED(src_reg);
|
GGML_UNUSED(src_reg);
|
||||||
GGML_UNUSED(dst);
|
GGML_UNUSED(dst);
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer) {
|
__device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer) {
|
||||||
|
|
@ -409,15 +343,6 @@ __device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer) {
|
||||||
return address;
|
return address;
|
||||||
}
|
}
|
||||||
|
|
||||||
// constexpr unsigned int int_log2(unsigned int x)
|
|
||||||
// {
|
|
||||||
// unsigned int result = 0;
|
|
||||||
// while (x >>= 1)
|
|
||||||
// {
|
|
||||||
// result++;
|
|
||||||
// }
|
|
||||||
// return result;
|
|
||||||
// }
|
|
||||||
|
|
||||||
#define CUDA_CONV2D_IMPLICT_BLOCK_SIZE 256
|
#define CUDA_CONV2D_IMPLICT_BLOCK_SIZE 256
|
||||||
void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
|
||||||
|
|
@ -6807,33 +6807,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||||
GGML_TYPE_F32, 1, 1, p0, p1, 1, 1, false));
|
GGML_TYPE_F32, 1, 1, p0, p1, 1, 1, false));
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto act_case : cases_sd) {
|
|
||||||
GGML_ASSERT(act_case[idx_sd["kw"]] == 3 || act_case[idx_sd["kw"]] == 1);
|
|
||||||
GGML_ASSERT(act_case[idx_sd["kh"]] == 3 || act_case[idx_sd["kh"]] == 1);
|
|
||||||
|
|
||||||
uint32_t p0 = act_case[idx_sd["kw"]] == 3 ? 1 : 0;
|
|
||||||
uint32_t p1 = act_case[idx_sd["kh"]] == 3 ? 1 : 0;
|
|
||||||
|
|
||||||
test_cases.emplace_back(new test_conv_2d_implicit(
|
|
||||||
{ act_case[idx_sd["iw"]], act_case[idx_sd["ih"]], act_case[idx_sd["Cin"]], act_case[idx_sd["B"]] },
|
|
||||||
{ act_case[idx_sd["kw"]], act_case[idx_sd["kh"]], act_case[idx_sd["Cin"]], act_case[idx_sd["Cout"]] },
|
|
||||||
GGML_TYPE_F16, 1, 1, p0, p1, 1, 1, true));
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto act_case : cases_sd) {
|
|
||||||
GGML_ASSERT(act_case[idx_sd["kw"]] == 3 || act_case[idx_sd["kw"]] == 1);
|
|
||||||
GGML_ASSERT(act_case[idx_sd["kh"]] == 3 || act_case[idx_sd["kh"]] == 1);
|
|
||||||
|
|
||||||
uint32_t p0 = act_case[idx_sd["kw"]] == 3 ? 1 : 0;
|
|
||||||
uint32_t p1 = act_case[idx_sd["kh"]] == 3 ? 1 : 0;
|
|
||||||
|
|
||||||
test_cases.emplace_back(new test_conv_2d_implicit(
|
|
||||||
{ act_case[idx_sd["iw"]], act_case[idx_sd["ih"]], act_case[idx_sd["Cin"]], act_case[idx_sd["B"]] },
|
|
||||||
{ act_case[idx_sd["kw"]], act_case[idx_sd["kh"]], act_case[idx_sd["Cin"]], act_case[idx_sd["Cout"]] },
|
|
||||||
GGML_TYPE_F32, 1, 1, p0, p1, 1, 1, true));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1}));
|
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1}));
|
||||||
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));
|
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue