make CI happy
This commit is contained in:
parent 2b5351a898
commit c1f67c19e0

@@ -85,8 +85,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
    // Warp tile
    const uint lane_id = tx % WARPSIZE;
    const uint warp_id = tx / WARPSIZE;
    const int mma_tid_x = warp_id / (BN / WN);
    const int mma_tid_y = warp_id % (BN / WN);

    // size of the warp subtile
    constexpr uint WMITER = (WM * WN) / (WARPSIZE * TM * TN * WNITER);
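The warp-tile indexing above splits the thread block into warps and maps each warp to a WM x WN sub-tile of the BM x BN block tile. A minimal host-side sketch of the same index math; the BN, WN, and thread-count values are assumed for illustration only:

#include <cstdio>

int main() {
    const unsigned WARPSIZE = 32;
    const unsigned BN = 128, WN = 64;   // assumed block-tile / warp-tile widths
    const unsigned NUM_THREADS = 128;   // assumed threads per block

    for (unsigned tx = 0; tx < NUM_THREADS; tx += WARPSIZE) {
        const unsigned warp_id   = tx / WARPSIZE;
        const unsigned mma_tid_x = warp_id / (BN / WN);  // warp row inside the block tile
        const unsigned mma_tid_y = warp_id % (BN / WN);  // warp column inside the block tile
        printf("warp %u -> tile (%u, %u)\n", warp_id, mma_tid_x, mma_tid_y);
    }
    return 0;
}
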
@@ -449,7 +449,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
    const int n = (ksplit > 0) ? gemm_i / PQ : z;
    const int col = (ksplit > 0) ? gemm_i % PQ : gemm_i;
    if (n < param.n && row < param.k && col < param.Oh * param.Ow){
        const uint outOffset = ksplit > 0 ?
            z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow +
            row * param.Oh * param.Ow + col :
            z * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
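The ternary above switches between a split-K output layout (an extra leading z dimension holding partial results) and the plain per-batch layout. A small sketch of the same offset arithmetic; all sizes below are made up for illustration:

#include <cstdio>

int main() {
    // Made-up sizes, for illustration only.
    const unsigned N = 2, K = 4, Oh = 8, Ow = 8;
    const unsigned n = 1, row = 3, col = 5, z = 1;  // batch, out-channel, spatial, split index

    // split-K layout: [z][n][k][Oh*Ow] -- each split writes its own partial output
    const unsigned split_off = z * N * K * Oh * Ow + n * K * Oh * Ow + row * Oh * Ow + col;
    // plain layout: [n][k][Oh*Ow] -- here z plays the role of the batch index
    const unsigned plain_off = z * K * Oh * Ow + row * Oh * Ow + col;

    printf("split-K offset %u, plain offset %u\n", split_off, plain_off);
    return 0;
}
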
@@ -626,7 +626,7 @@ __device__ __forceinline__ void ldmatrix_b(

    static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
    static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8");

    uint32_t (&reg_) [4][8] = reinterpret_cast<uint32_t(&)[4][8]>(reg);
    unsigned int logical_offset = (threadIdx.x % 32) * smem_stride;
    unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4);
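The XOR on swizzled_offset flips bit 3 of the half-element offset whenever bit 7 is set, staggering rows that would otherwise hit the same shared-memory bank group before the ldmatrix load. A tiny host-side sketch of that mapping; the smem_stride value is assumed for illustration:

#include <cstdio>

int main() {
    const unsigned smem_stride = 8;  // assumed stride in half elements, illustration only
    for (unsigned lane = 0; lane < 32; ++lane) {
        unsigned logical  = lane * smem_stride;
        unsigned swizzled = logical ^ ((logical & 0b10000000u) >> 4);  // bit 7 set -> flip bit 3
        printf("lane %2u: %4u -> %4u\n", lane, logical, swizzled);
    }
    return 0;
}
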
@@ -739,11 +739,11 @@ constexpr unsigned int MMA_N = 8;
    constexpr int BUFFER_SIZE = BM * BK + BK * BN;

    // declare register storage
    // ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together
    uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2];
    uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2];
    uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n];

    // convenience cast to half for register storage
    half (&acc_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_n][4] = reinterpret_cast<half(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4]>(acc_register);
    half (&A_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] = reinterpret_cast<half(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]>(A_register);
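The reinterpret_casts above only re-view the accumulator and operand registers: each uint32_t carries two packed halves, which is the layout the mma PTX instructions consume. A self-contained sketch of the same aliasing, independent of the kernel's actual tile sizes:

#include <cuda_fp16.h>
#include <cstdio>

__global__ void pack_demo() {
    uint32_t acc[2] = {0, 0};                             // two packed 32-bit registers
    half (&acc_)[4] = reinterpret_cast<half(&)[4]>(acc);  // same 64 bits viewed as 4 halves
    acc_[0] = __float2half(1.0f);
    acc_[1] = __float2half(2.0f);
    acc_[2] = __float2half(3.0f);
    acc_[3] = __float2half(4.0f);
    if (threadIdx.x == 0) {
        printf("packed words: 0x%08x 0x%08x\n", acc[0], acc[1]);
    }
}

int main() {
    pack_demo<<<1, 32>>>();
    cudaDeviceSynchronize();
    return 0;
}
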
@@ -827,7 +827,7 @@ constexpr unsigned int MMA_N = 8;

    // reuse smem
    half *smemoutput = shmem;
    const uint lane_id = threadIdx.x % WARPSIZE;
    const uint mma_row = lane_id / 4;
    const uint mma_col = lane_id % 4;
    const uint output_lds_addr = warp_m * WM * BN/2 + lane_id * BN/2 + warp_n * WN/2;
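output_lds_addr places each lane at the start of its warp's row in the reused shared-memory staging tile; the /2 factors suggest the staging buffer is addressed in packed pairs of halves, though that reading is an assumption rather than something this hunk confirms. A sketch of the address arithmetic with assumed tile sizes:

#include <cstdio>

int main() {
    // Assumed tile sizes, for illustration only.
    const unsigned BN = 128, WM = 64, WN = 64;

    for (unsigned warp_m = 0; warp_m < 2; ++warp_m) {
        for (unsigned warp_n = 0; warp_n < 2; ++warp_n) {
            for (unsigned lane_id = 0; lane_id < 2; ++lane_id) {
                const unsigned addr = warp_m * WM * BN/2 + lane_id * BN/2 + warp_n * WN/2;
                printf("warp (%u,%u) lane %u -> staging offset %u\n", warp_m, warp_n, lane_id, addr);
            }
        }
    }
    return 0;
}
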
@@ -845,7 +845,7 @@ constexpr unsigned int MMA_N = 8;
    for (unsigned int mma_n = i * mma_tiles_per_warp_n/2; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++)
    {
        uint32_t (&reg_)[2] = reinterpret_cast<uint32_t(&)[2]>(acc_register_[mma_m][mma_n]);
        uint idx = output_sts_addr +
            mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N;
        idx = idx ^ ((idx & 0b1110000000) >> 4);
        uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx]);
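The store-side swizzle folds bits 7..9 of idx down onto bits 3..5. Whatever bits are chosen, the mapping has to stay a bijection within the staging tile or two MMA fragments would overwrite each other; a quick host-side check of that property for an assumed tile size:

#include <cstdio>

int main() {
    const unsigned TILE = 1024;  // assumed number of half elements in the staging tile
    bool seen[1024] = {false};
    bool ok = true;
    for (unsigned idx = 0; idx < TILE; ++idx) {
        unsigned s = idx ^ ((idx & 0b1110000000u) >> 4);
        if (s >= TILE || seen[s]) { ok = false; break; }
        seen[s] = true;
    }
    printf("swizzle is a permutation of [0,%u): %s\n", TILE, ok ? "yes" : "no");
    return 0;
}
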
@@ -902,7 +902,7 @@ constexpr static int conv_shapes[][NUM_VARIANTS] = {
};

template <typename T, unsigned int CONV_SHAPE>
static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) {

    const uint BM = conv_shapes[0][CONV_SHAPE];
    const uint BN = conv_shapes[1][CONV_SHAPE];
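conv_shapes is a compile-time table of tile parameters indexed by kernel variant, and the CONV_SHAPE template argument selects one column. A minimal sketch of the same lookup pattern; the table values below are placeholders, since the real entries are outside this hunk:

#include <cstdio>

constexpr int NUM_VARIANTS = 2;
// Rows: BM, BN. Placeholder values, not the kernel's real configuration.
constexpr int conv_shapes[][NUM_VARIANTS] = {
    {128, 64},  // BM per variant
    {128, 64},  // BN per variant
};

template <unsigned int CONV_SHAPE>
static void print_shape() {
    const unsigned BM = conv_shapes[0][CONV_SHAPE];
    const unsigned BN = conv_shapes[1][CONV_SHAPE];
    printf("variant %u: BM=%u BN=%u\n", CONV_SHAPE, BM, BN);
}

int main() {
    print_shape<0>();
    print_shape<1>();
    return 0;
}
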
@@ -920,7 +920,7 @@ static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D,
    int threadz = 1; // threadz number per block
    dim3 thblock(NUM_THREADS, thready, threadz);
    dim3 grid(blockx, blocky, blockz);

    conv2d_implicit_kernel<T, BM, BN, BK, WM, WN,
        WNITER, TM, TN, NUM_THREADS, 1, false, 0><<<grid, thblock, 0, st>>>(X_D, K_D, Y_D, P);
}
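The launch pairs a fixed block shape (NUM_THREADS, thready, threadz) with a grid sized from the problem; the exact blockx/blocky/blockz formulas are outside this hunk, so the ceil-division below is only an assumed, typical way such a grid would be derived:

#include <cstdio>

static unsigned ceil_div(unsigned a, unsigned b) { return (a + b - 1) / b; }

int main() {
    // Assumed problem and tile sizes, for illustration only.
    const unsigned n = 2, k = 256, Oh = 56, Ow = 56;
    const unsigned BM = 128, BN = 128;

    const unsigned blockx = ceil_div(n * Oh * Ow, BN);  // output-pixel tiles (assumed mapping)
    const unsigned blocky = ceil_div(k, BM);            // output-channel tiles (assumed mapping)
    const unsigned blockz = 1;

    printf("grid = (%u, %u, %u)\n", blockx, blocky, blockz);
    return 0;
}
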
@@ -991,6 +991,7 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
static void conv2d_implicit_cuda_f32(ggml_backend_cuda_context & ctx, const float * X_D, const float * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) {
    conv2d_implicit_cuda<float, 1>(X_D, K_D, Y_D, P, st);
    GGML_UNUSED(ctx);
+   GGML_UNUSED(cc);
}

void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
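The added GGML_UNUSED(cc) is the visible change here: the f32 path ignores the compute-capability argument, and marking it used keeps -Wunused-parameter (and therefore -Werror CI builds) quiet. The macro is, in essence, the usual cast-to-void idiom; its exact definition lives in ggml's headers, so the line below is a sketch rather than a quote:

// Typical form of the idiom (the real GGML_UNUSED is defined in ggml's headers):
#define GGML_UNUSED(x) (void)(x)
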
@@ -137,7 +137,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
    unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
    dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
    dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
    if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
        curR < param.r && curS < param.s && curC < param.c){
        const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
        dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
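Here the destination index is swizzled twice with two mask/shift pairs before the float4 store, breaking up bank conflicts at two granularities. The helper below factors that pattern out; the mask and bit constants are placeholders, since the real SWIZZLE_MASK_*/SWIZZLE_BITS_* values are defined outside this hunk:

#include <cstdio>

// Generic one-step XOR swizzle: fold selected high bits down onto low bits.
static unsigned swizzle(unsigned idx, unsigned mask, unsigned bits) {
    return idx ^ ((idx & mask) >> bits);
}

int main() {
    // Placeholder constants; the kernel's real values are not shown in this hunk.
    const unsigned MASK_1 = 0b111000000, BITS_1 = 3;
    const unsigned MASK_2 = 0b000111000, BITS_2 = 3;

    for (unsigned idx = 0; idx < 16; ++idx) {
        unsigned s = swizzle(swizzle(idx * 8, MASK_1, BITS_1), MASK_2, BITS_2);
        printf("%3u -> %3u\n", idx * 8, s);
    }
    return 0;
}
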
@@ -199,7 +199,7 @@ __device__ __forceinline__ void tileMemcpyLoadA(
    const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
    int curH = posh_ori + curR * param.d_h; // input h
    int curW = posw_ori + curS * param.d_w; // input w
    if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
        curR < param.r && curS < param.s && curC < param.c){
        const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
        dst_reg[i] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
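The chained fastmodulo calls peel the input-channel component out of the flattened filter index (which runs over r, s, and c together); fastdiv/fastmodulo are ggml-cuda's integer-division helpers built on precomputed per-dimension constants (hence the param.*_fastdiv fields). The plain / and % below are an arithmetic equivalent, with assumed sizes, for illustration only:

#include <cstdio>

int main() {
    // Assumed filter and channel sizes, for illustration only.
    const unsigned R = 3, S = 3, C = 64;

    for (unsigned k = 0; k < R * S * C; k += 97) {
        const unsigned curR = k / (S * C);        // filter row
        const unsigned curS = (k % (S * C)) / C;  // filter column
        const unsigned curC = (k % (S * C)) % C;  // input channel, as in the fastmodulo chain
        printf("k=%4u -> r=%u s=%u c=%2u\n", k, curR, curS, curC);
    }
    return 0;
}
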
@@ -215,7 +215,7 @@ __device__ __forceinline__ void tileMemcpyLoadA(
    GGML_UNUSED(inChannelOffset);
    GGML_UNUSED(param);
    NO_DEVICE_CODE;
#endif
}
@@ -299,7 +299,7 @@ __device__ __forceinline__ void tileMemcpySwizzleStore(
    // # of threads is multiple of # of columns in the tile
    constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
    static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);

    // flatten out 2d grid of threads into in order of increasing threadIdx.x
    const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;
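The static_assert guarantees the flattened thread count tiles the vectorized column count evenly, so thread_idx splits back into a (row, column) pair with no remainder handling. A compact sketch of that decomposition under assumed sizes:

#include <cstdio>

int main() {
    // Assumed sizes, for illustration only.
    constexpr unsigned TILE_COLS = 64;
    constexpr unsigned TILE_COLS_VECTORIZED = TILE_COLS / 8;  // one float4 carries 8 halves
    constexpr unsigned NUM_THREADS = 128;
    static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0, "threads must tile the columns");

    for (unsigned thread_idx = 0; thread_idx < NUM_THREADS; thread_idx += 33) {
        const unsigned thread_row = thread_idx / TILE_COLS_VECTORIZED;
        const unsigned thread_col = thread_idx % TILE_COLS_VECTORIZED;
        printf("thread %3u -> (row %2u, col %u)\n", thread_idx, thread_row, thread_col);
    }
    return 0;
}
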
@@ -312,7 +312,7 @@ __device__ __forceinline__ void tileMemcpySwizzleStore(

    // compile time check that we provided the right amount of registers for storage
    static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);

    #pragma unroll
    for (unsigned int i = 0; i < NUM_ITERS; i++)
    {
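The static_assert ties the size of the caller-provided register array to the number of copy iterations, so the unrolled loop neither overruns nor wastes registers. The relation below, where the iteration count falls out of tile size over thread count, is an assumed convention with illustrative sizes:

#include <cstdio>

int main() {
    // Assumed sizes, for illustration only.
    constexpr unsigned TILE_ROWS = 64;
    constexpr unsigned TILE_COLS_VECTORIZED = 8;
    constexpr unsigned NUM_THREADS = 128;
    constexpr unsigned ELEMENTS_PER_THREAD = 4;  // size of the caller-provided register array

    // Each iteration the whole block copies NUM_THREADS vectorized elements.
    constexpr unsigned NUM_ITERS = (TILE_ROWS * TILE_COLS_VECTORIZED) / NUM_THREADS;
    static_assert(ELEMENTS_PER_THREAD == NUM_ITERS,
                  "caller must provide exactly one register slot per copy iteration");

    printf("each thread copies %u vectorized elements\n", NUM_ITERS);
    return 0;
}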