WIP: tensor core code compiled ok
This commit is contained in:
parent
2715341c1d
commit
80a996cfc0
@@ -730,7 +730,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
}
}

#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE

template <unsigned int mma_tiles_per_warp_m, unsigned int mma_tiles_per_warp_k, unsigned int smem_stride>
__device__ __forceinline__ void ldmatrix_a(
@@ -738,6 +738,7 @@ __device__ __forceinline__ void ldmatrix_a(
half (&reg)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]
)
{
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 8");
static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
@@ -880,7 +881,11 @@ __device__ __forceinline__ void ldmatrix_a(
: "=r"(reg_[6][3][0]), "=r"(reg_[6][3][1]), "=r"(reg_[7][3][0]), "=r"(reg_[7][3][1])
: "r"(src_addr + 96 * smem_stride_)
);

#else
GGML_UNUSED(src);
GGML_UNUSED(reg);
NO_DEVICE_CODE;
#endif
}

template <unsigned int mma_tiles_per_warp_k, unsigned int mma_tiles_per_warp_n, unsigned int smem_stride>
@@ -889,10 +894,11 @@ __device__ __forceinline__ void ldmatrix_b(
half (&reg)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]
)
{
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8");

// uint32_t (&reg_) [4][8] = reinterpret_cast<uint32_t(&)[4][8]>(reg);
uint32_t (&reg_) [4][8] = reinterpret_cast<uint32_t(&)[4][8]>(reg);
// const unsigned int logical_offset = ((threadIdx.x % 8) * smem_stride) + (((threadIdx.x % 32) / 8) * 8);
// unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b11100000000) >> 5);
// uint32_t src_addr = cvta_to_shared_u32(src + swizzled_offset);
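For context on the fragment loads that ldmatrix_a/ldmatrix_b wrap: the sketch below shows a single x4 ldmatrix load in isolation. It is an illustration only, not code from this commit; it assumes the cvta_to_shared_u32 helper declared in the companion header and a pointer that already carries the swizzled shared-memory offset computed above.

```cuda
#include <cstdint>
#include <cuda_fp16.h>

// Illustration only: one ldmatrix.x4 load of four 8x8 half tiles from shared
// memory into per-thread registers. Every lane of the warp supplies its own
// shared-memory address; requires Turing/Ampere or newer, matching the
// GGML_CUDA_CC_AMPERE guard used in this file.
__device__ __forceinline__ void ldmatrix_x4_sketch(const half * smem_ptr, uint32_t (&frag)[4]) {
    const uint32_t addr = cvta_to_shared_u32(smem_ptr);
    asm volatile(
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];"
        : "=r"(frag[0]), "=r"(frag[1]), "=r"(frag[2]), "=r"(frag[3])
        : "r"(addr));
}
```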
@@ -979,15 +985,20 @@ __device__ __forceinline__ void ldmatrix_b(
// : "r"(src_addr ^ 0b1000000)
: "r"(src_addr + 32 * smem_stride_)
);

#else
GGML_UNUSED(src);
GGML_UNUSED(reg);
NO_DEVICE_CODE;
#endif
}

template<const int BM, const int BN, const int BK, const int WM, const int WN,
const int WK, const int NUM_THREADS>
static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
static __global__ void conv2d_implicit_kernel_tc(const half * __restrict__ input,
const half * __restrict__ kernel,
half * __restrict__ output,
const param_t param) {
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
constexpr unsigned int MMA_M = 16;
constexpr unsigned int MMA_N = 8;
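MMA_M = 16 and MMA_N = 8 refer to the m16n8k16 tensor-core tile that the accumulators below are sized for. As a reference point only (not the helper this kernel actually calls), a minimal warp-level wrapper for that instruction with half accumulators looks roughly like the sketch below; it needs Ampere or newer, matching the GGML_CUDA_CC_AMPERE guard.

```cuda
#include <cstdint>

// Illustration only: one warp-wide m16n8k16 MMA, D = A*B + C, all operands in
// half precision packed two halves per 32-bit register.
__device__ __forceinline__ void mma_m16n8k16_f16_sketch(uint32_t (&d)[2],
                                                        const uint32_t (&a)[4],
                                                        const uint32_t (&b)[2],
                                                        const uint32_t (&c)[2]) {
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
        "{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};"
        : "=r"(d[0]), "=r"(d[1])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
          "r"(b[0]), "r"(b[1]),
          "r"(c[0]), "r"(c[1]));
}
```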
@@ -997,18 +1008,18 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
const uint weightKOffset = param.c * param.r * param.s;

// for convenience/readability in index calculations
const unsigned int A_stride = K;
const unsigned int B_stride = N;
const unsigned int CD_stride = N;
// const unsigned int A_stride = K;
// const unsigned int B_stride = N;
// const unsigned int CD_stride = N;

// calculate how many bits of shared memory indices are going to be swizzled, and create masks
constexpr unsigned int SWIZZLE_BITS_B = int_log2(BN / 8);
// constexpr unsigned int SWIZZLE_BITS_B = int_log2(BN / 8);

// loop bounds, constexpr where possible allows for loop unrolling
constexpr unsigned int mma_tiles_per_warp_k = 4;
constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M;
constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N;
const unsigned int num_block_tiles_k = K / BK;
const unsigned int num_block_tiles_k = (K + (BK-1)) / BK;

// calculate block/warp indices
const unsigned int block_m = blockIdx.y;
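The switch from K / BK to (K + (BK-1)) / BK changes the K-tile count from truncating to ceiling division, so a reduction dimension that is not a multiple of BK still gets a final partial tile. A stand-alone equivalent, for reference only:

```cuda
// Illustration only: integer ceiling division as used for num_block_tiles_k.
constexpr unsigned int ceil_div(unsigned int k, unsigned int bk) {
    return (k + bk - 1) / bk;
}
static_assert(ceil_div(96, 32) == 3 && ceil_div(100, 32) == 4, "rounds up instead of truncating");
```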
@@ -1057,8 +1068,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,

// prefetch the first block tile of A,B into shared memory
// half* A_block_gmem = input + (block_m * BM * A_stride);
half* A_block_gmem = input;
half* B_block_gmem = kernel + (block_n * weightKOffset);
const half* A_block_gmem = input;
const half* B_block_gmem = kernel + (block_n * weightKOffset);
tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, inChannelOffset, param);
tileMemcpySwizzleB<BN, NUM_THREADS>(B_block_gmem, B_block_smem, weightKOffset, param);
@@ -1074,10 +1085,10 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
if (block_k != num_block_tiles_k)
{
// half* A_block_gmem = A + (block_m * BM * A_stride) + (block_k * BK);
half* A_block_gmem = input;
half* B_block_gmem = kernel + (block_n * weightKOffset);
tileMemcpyLoad<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, block_k * BK, inChannelOffset, param);
tileMemcpyLoad<BN, BK, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param);
const half* A_block_gmem = input;
const half* B_block_gmem = kernel + (block_n * weightKOffset);
tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, block_k * BK, inChannelOffset, param);
tileMemcpyLoadB<BN, BK, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param);
}
half* A_warp_tile = A_block_smem + (warp_m * WM * BK);
half* B_warp_tile = B_block_smem + (warp_n * WN * BK);
@@ -1124,14 +1135,14 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
}

// reuse smem
half *smemoutput = reinterpret_cast<half *>(shmem);
half *smemoutput = shmem;
const uint lane_id = threadIdx.x % WARPSIZE;
const uint mma_row = lane_id / 4;
const uint mma_col = lane_id % 4;
const uint output_lds_addr = warp_id * WSUBM * WSUBN + lane_id;
const uint output_lds_addr = warp_m * WM * BN/2 + lane_id * BN/2 + warp_n * WN/2;
const uint output_sts_addr = warp_m * WM * BN/2 + mma_row * BN/2 + warp_n * WN/2 + mma_col * 2;
const uint m_idx = by * BN + mma_tid_y * WN;
const uint n_idx = block_m * BM + warp_m * WM;
const uint m_idx = block_n * BN + warp_n * WN;
const uint n_idx = block_m * BM + warp_m * WM + lane_id;

#pragma unroll
for (int i = 0; i < 2; ++i)
@@ -1142,72 +1153,42 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
{
// output sts
uint32_t (&reg_)[2] = reinterpret_cast<uint32_t(&)[2]>(acc_register_[mma_m][mma_n]);
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[output_sts_addr +
mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N]);
const uint idx = output_sts_addr +
mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N;
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx]);
dst_ptr[0] = reg_[0];
dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[output_sts_addr +
mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N + 8 * BN / 2]);
dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx + 8 * BN / 2]);
dst_ptr[0] = reg_[1];
}
}
__syncthreads();
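The two 32-bit stores per MMA tile follow the standard m16n8 half-precision accumulator layout: each lane owns two adjacent halves in one row plus the same two columns eight rows below, which is what mma_row, mma_col and the extra + 8 * BN / 2 offset encode. A small sketch of that mapping (illustration only, not the commit's code):

```cuda
// Illustration only: which output rows/columns lane `lane` of a warp holds for
// one m16n8 half-precision accumulator tile (reg_[0] covers row0, reg_[1] row1).
__device__ __forceinline__ void mma_m16n8_owner_sketch(unsigned int lane,
                                                       unsigned int &row0,
                                                       unsigned int &row1,
                                                       unsigned int &col) {
    row0 = lane / 4;        // rows 0..7, written first
    row1 = row0 + 8;        // rows 8..15, written with the + 8 * BN / 2 offset
    col  = (lane % 4) * 2;  // two adjacent halves per 32-bit register
}
```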
#pragma unroll
const uint row = m_idx + j * WSUBN + (lane_id + subk * WARPSIZE) / WSUBM;
const uint gemm_i = n_idx + i * WSUBM + (lane_id + subk * WARPSIZE) % WSUBM;
const int n = fastdiv(gemm_i, param.OHOW_fastdiv);
const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
if (n < param.n && row < param.k && col < param.Oh * param.Ow){
// int outOffset = z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + (n_idx + j * 32);
// if (n < param.n && (m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow)
// param.interm[outOffset] = smemoutput[output_lds_addr + subk * 32];
const uint outOffset = ksplit > 0 ?
z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow +
row * param.Oh * param.Ow + col :
z * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
output[outOffset] = smemoutput[output_lds_addr + subk * WARPSIZE];
for (int subk = 0; subk < WN / 2; ++subk){
for (int j = 0; j < 4; ++j){
const uint row = m_idx + subk + i * WN / 2;
const uint gemm_i = n_idx + j*32;
const int n = fastdiv(gemm_i, param.OHOW_fastdiv);
const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
if(n < param.n && row < param.k && col < param.Oh * param.Ow){
// int outOffset = z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + (n_idx + j * 32);
// if (n < param.n && (m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow)
// param.interm[outOffset] = smemoutput[output_lds_addr + subk * 32];
const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
output[outOffset] = smemoutput[output_lds_addr + subk + j*32*BN/2];
}
}
}
}
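The fastdiv/fastmodulo pair recovers the batch index and the flattened spatial position from one row of the implicit GEMM; the plain division it replaces is shown below for reference (illustration only).

```cuda
// Illustration only: split a flattened N*Oh*Ow row index into (batch, position).
static inline void split_gemm_row_sketch(int gemm_i, int Oh, int Ow, int * n, int * col) {
    *n   = gemm_i / (Oh * Ow);  // which image in the batch
    *col = gemm_i % (Oh * Ow);  // flattened (oh, ow) inside that image
}
```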
//////////////
// epilogue //
//////////////
// half alpha_ = (half)alpha;
// half beta_ = (half)beta;
// half C_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4];

// // calculate pointers for this warps C and D tiles
// half* C_block_gmem = C + (block_m * BM_dim * CD_stride) + (block_n * BN_dim);
// half* C_warp_gmem = C_block_gmem + (warp_m * WM_dim * CD_stride) + (warp_n * WN_dim);
// half* D_block_gmem = D + (block_m * BM_dim * CD_stride) + (block_n * BN_dim);
// half* D_warp_gmem = D_block_gmem + (warp_m * WM_dim * CD_stride) + (warp_n * WN_dim);

// for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
// {
// for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++)
// {
// half* C_mma_tile = C_warp_gmem + (mma_m * MMA_M_dim * CD_stride) + (mma_n * MMA_N_dim);
// ldmatrix_m16n8_gmem(C_mma_tile, C_register[mma_m][mma_n], N * sizeof(half));

// // scale C by beta
// acc_register_[mma_m][mma_n][0] = acc_register_[mma_m][mma_n][0] * alpha_ + C_register[mma_m][mma_n][0] * beta_;
// acc_register_[mma_m][mma_n][1] = acc_register_[mma_m][mma_n][1] * alpha_ + C_register[mma_m][mma_n][1] * beta_;
// acc_register_[mma_m][mma_n][2] = acc_register_[mma_m][mma_n][2] * alpha_ + C_register[mma_m][mma_n][2] * beta_;
// acc_register_[mma_m][mma_n][3] = acc_register_[mma_m][mma_n][3] * alpha_ + C_register[mma_m][mma_n][3] * beta_;
// }
// }

// for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
// {
// for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++)
// {
// half* D_mma_tile = D_warp_gmem + (mma_m * MMA_M_dim * CD_stride) + (mma_n * MMA_N_dim);
// stmatrix_m16n8(D_mma_tile, acc_register_[mma_m][mma_n], N * sizeof(half));
// }
// }

#else
GGML_UNUSED(input);
GGML_UNUSED(kernel);
GGML_UNUSED(output);
GGML_UNUSED(param);
NO_DEVICE_CODE;
#endif
}

#endif

#define NUM_VARIANTS 6
@@ -1266,11 +1247,53 @@ static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D,
}
}

static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const param_t P, cudaStream_t st) {
static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) {
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
if (GGML_CUDA_CC_IS_NVIDIA(cc) && ampere_mma_available(cc) && P.layout == 0 && P.c % 8 == 0) {
constexpr unsigned int BM_dim = 256;
constexpr unsigned int BN_dim = 256;
constexpr unsigned int BK_dim = 32;

constexpr unsigned int WARPS_PER_BLOCK_M = 2;
constexpr unsigned int WARPS_PER_BLOCK_N = 4;
constexpr unsigned int WARPS_PER_BLOCK_K = 4;

constexpr unsigned int WM_dim = BM_dim / WARPS_PER_BLOCK_M;
constexpr unsigned int WN_dim = BN_dim / WARPS_PER_BLOCK_N;
constexpr unsigned int WK_dim = BK_dim / WARPS_PER_BLOCK_K;
const unsigned int BlocksM = (P.n * P.Oh * P.Ow + BM_dim - 1) / BM_dim;
const unsigned int BlocksN = (P.k + BN_dim - 1) / BN_dim;
constexpr unsigned int ThreadsM = WARPS_PER_BLOCK_M;
constexpr unsigned int ThreadsN = WARPSIZE * WARPS_PER_BLOCK_N;
constexpr unsigned int NumThreads = ThreadsM * ThreadsN;
const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half);
dim3 gridDim(BlocksN, BlocksM);
dim3 blockDim(ThreadsN, ThreadsM);

int id = ggml_cuda_get_device();
ggml_cuda_pool_alloc<half> x_f16(ctx.pool(id));

const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(GGML_TYPE_F32);
GGML_ASSERT(to_fp16_cuda != nullptr);
size_t ne = P.c * P.h * P.w * P.n;
x_f16.alloc(ne);
to_fp16_cuda(X_D, x_f16.get(), ne, st);
const half *X_H = x_f16.get();
ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n);
conv2d_implicit_kernel_tc<BM_dim, BN_dim, BK_dim,
WM_dim, WN_dim, WK_dim, NumThreads>
<<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_D, Y_H.get(), P);
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st);
}else{
conv2d_implicit_cuda<half, 1>(X_D, K_D, Y_D, P, st);
}
#else
conv2d_implicit_cuda<half, 1>(X_D, K_D, Y_D, P, st);
#endif
}
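For the tile sizes chosen in conv2d_implicit_cuda_f16, the launch arithmetic works out as in the small host-side check below. It is an illustration only and simply re-derives the constants computed above.

```cuda
// Illustration only: threads per block and dynamic shared memory implied by
// BM = BN = 256, BK = 32 and a 2 x 4 warp grid per block.
#include <cstdio>
#include <cuda_fp16.h>

int main() {
    constexpr unsigned int BM = 256, BN = 256, BK = 32;
    constexpr unsigned int WARPS_M = 2, WARPS_N = 4, WARPSIZE = 32;
    constexpr unsigned int num_threads = WARPS_M * (WARPSIZE * WARPS_N);       // 256 threads per block
    constexpr size_t shmem_bytes = (BM * BK + BK * BN) * 2 * sizeof(half);     // same expression as shmem_bytes above
    printf("threads/block = %u, dynamic smem = %zu bytes\n", num_threads, shmem_bytes);
    return 0;
}
```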

static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const param_t P, cudaStream_t st) {
static void conv2d_implicit_cuda_f32(ggml_backend_cuda_context & ctx, const float * X_D, const float * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) {
conv2d_implicit_cuda<float, 1>(X_D, K_D, Y_D, P, st);
}
@@ -1286,6 +1309,7 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor *

cudaStream_t st = ctx.stream();
const int cc = ggml_cuda_info().devices[ctx.device].cc;

const int32_t * p = (const int32_t *) dst->op_params;
const int ST_X = p[0]; // stride_x
@@ -1333,8 +1357,8 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor *
params.layout = LT;

if (kernel->type == GGML_TYPE_F16) {
conv2d_implicit_cuda_f16(X_D, (half *) K_D, Y_D, params, st);
conv2d_implicit_cuda_f16(ctx, X_D, (half *) K_D, Y_D, cc, params, st);
} else {
conv2d_implicit_cuda_f32(X_D, K_D, Y_D, params, st);
conv2d_implicit_cuda_f32(ctx, X_D, K_D, Y_D, cc, params, st);
}
}
@@ -31,9 +31,10 @@ typedef struct{
template<unsigned int TILE_ROWS,
unsigned int NUM_THREADS>
__device__ __forceinline__ void tileMemcpySwizzleB(
half* src,
const half* src,
half* dst,
const unsigned int src_stride
const unsigned int src_stride,
param_t param
)
{
// constexpr unsigned int SWIZZLE_MASK = 0b111 << SWIZZLE_BITS;
@@ -109,7 +110,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
if (thread_row < param.k && curR < param.R && curS < param.S && curC < param.c){
if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[src_index])[0];
}else{ // read 4 halves
dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
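The two XOR lines above are the usual shared-memory swizzle: selected high bits of the vectorized index are folded into its low bits so that rows which would otherwise hit the same banks get spread out before ldmatrix reads the tile back. The same operation in stand-alone form (illustration only; the real SWIZZLE_MASK_*/SWIZZLE_BITS_* values are derived from the tile shape):

```cuda
// Illustration only: XOR-fold the masked index bits down by `bits` positions.
__host__ __device__ constexpr unsigned int swizzle_index(unsigned int index,
                                                         unsigned int mask,
                                                         unsigned int bits) {
    return index ^ ((index & mask) >> bits);
}
// e.g. swizzle_index(idx, SWIZZLE_MASK_1, SWIZZLE_BITS_1) reproduces the first XOR line above
```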
@@ -123,7 +124,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
template<unsigned int TILE_ROWS,
unsigned int NUM_THREADS>
__device__ __forceinline__ void tileMemcpySwizzleA(
half* src,
const half* src,
half* dst,
// const unsigned int src_stride,
const unsigned int inChannelOffset,
@@ -175,7 +176,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
curR < param.R && curS < param.S && curC < param.c){
curR < param.r && curS < param.s && curC < param.c){
// const unsigned int src_index = thread_row * src_stride_vectorized + thread_col;
const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
@@ -191,7 +192,7 @@ unsigned int TILE_COLS,
unsigned int NUM_THREADS,
unsigned int ELEMENTS_PER_THREAD>
__device__ __forceinline__ void tileMemcpyLoadA(
half* src,
const half* src,
float4 (&dst_reg)[ELEMENTS_PER_THREAD],
// const unsigned int src_stride,
const unsigned int block_k,
@@ -200,7 +201,7 @@ __device__ __forceinline__ void tileMemcpyLoadA(
)
{
// reinterpret input/output as float4
float4* src_float4 = reinterpret_cast<float4*>(src);
// const float4* src_float4 = reinterpret_cast<const float4*>(src);
// const unsigned int src_stride_vectorized = src_stride / 8;

// # of threads is multiple of # of columns in the tile
@@ -239,7 +240,7 @@ __device__ __forceinline__ void tileMemcpyLoadA(
int curH = posh_ori + curR * param.d_h; // input h
int curW = posw_ori + curS * param.d_w; // input w
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
curR < param.R && curS < param.S && curC < param.c){
curR < param.r && curS < param.s && curC < param.c){
// const unsigned int src_index = thread_row * src_stride_vectorized + thread_col;
const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
dst_reg[i] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
@@ -256,7 +257,7 @@ unsigned int TILE_COLS,
unsigned int NUM_THREADS,
unsigned int ELEMENTS_PER_THREAD>
__device__ __forceinline__ void tileMemcpyLoadB(
half* src,
const half* src,
float4 (&dst_reg)[ELEMENTS_PER_THREAD],
const unsigned int block_k,
const unsigned int src_stride,
@@ -264,7 +265,7 @@ __device__ __forceinline__ void tileMemcpyLoadB(
)
{
// reinterpret input/output as float4
float4* src_float4 = reinterpret_cast<float4*>(src);
// const float4* src_float4 = reinterpret_cast<const float4*>(src);
// const unsigned int src_stride_vectorized = src_stride / 8;

// # of threads is multiple of # of columns in the tile
@@ -295,7 +296,7 @@ __device__ __forceinline__ void tileMemcpyLoadB(
// dst_reg[i] = src_float4[src_index];
// thread_row += ROW_STEP;
const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8;
if (thread_row < param.k && curR < param.R && curS < param.S && curC < param.c){
if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){
dst_reg[i] = reinterpret_cast<const float4 *>(&src[src_index])[0];
}else{ // read 4 halves
dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
@@ -449,5 +450,15 @@ __device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer) {

#endif

// constexpr unsigned int int_log2(unsigned int x)
// {
// unsigned int result = 0;
// while (x >>= 1)
// {
// result++;
// }
// return result;
// }

#define CUDA_CONV2D_IMPLICT_BLOCK_SIZE 256
void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
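The commented-out int_log2 above computes floor(log2(x)) at compile time; it is what the (now disabled) SWIZZLE_BITS_B = int_log2(BN / 8) line in the kernel relied on. An equivalent, compile-checkable version for reference (illustration only, mirroring the disabled code rather than whatever definition the build actually uses):

```cuda
// Illustration only: floor(log2(x)) by counting how often x can be halved.
constexpr unsigned int int_log2_sketch(unsigned int x) {
    unsigned int result = 0;
    while (x >>= 1) {
        result++;
    }
    return result;
}
static_assert(int_log2_sketch(256 / 8) == 5, "BN = 256 -> 32 float4 columns -> 5 swizzle bits");
```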