Add implicit GEMM convolution operation for 2D tensors in CUDA
parent 9961d244f2
commit 8a589317b6
@ -512,6 +512,7 @@ extern "C" {
        GGML_OP_IM2COL,
        GGML_OP_IM2COL_BACK,
        GGML_OP_CONV_2D,
        GGML_OP_CONV_2D_IMPLICIT,
        GGML_OP_CONV_3D,
        GGML_OP_CONV_2D_DW,
        GGML_OP_CONV_TRANSPOSE_2D,
@ -1941,6 +1942,17 @@ extern "C" {
            int d0,  // dilation dimension 0
            int d1); // dilation dimension 1

    GGML_API struct ggml_tensor * ggml_conv_2d_implicitgemm(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,  // convolution kernel [KW, KH, IC, OC]
            struct ggml_tensor  * b,  // input data [W, H, C, N]
            int s0,  // stride dimension 0
            int s1,  // stride dimension 1
            int p0,  // padding dimension 0
            int p1,  // padding dimension 1
            int d0,  // dilation dimension 0
            int d1); // dilation dimension 1

    GGML_API struct ggml_tensor * ggml_conv_3d(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,  // kernel [KW, KH, KD, IC * OC]
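For reference, a minimal usage sketch of the new API. The tensor shapes, variable names and surrounding context below are illustrative only and are not part of this commit:

    // hypothetical caller: 3x3 kernel, 64 -> 128 channels, 224x224 input, batch of 1
    // kernel layout is [KW, KH, IC, OC], input layout is [W, H, C, N]
    struct ggml_tensor * kern = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, 64, 128);
    struct ggml_tensor * inp  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 224, 224, 64, 1);
    struct ggml_tensor * out  = ggml_conv_2d_implicitgemm(ctx, kern, inp,
                                                          1, 1,  // s0, s1
                                                          1, 1,  // p0, p1
                                                          1, 1); // d0, d1
    // out has shape [OW, OH, OC, N] and is computed by the CUDA backend via GGML_OP_CONV_2D_IMPLICIT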
@ -0,0 +1,394 @@
#include "conv2d-implicit.cuh"
#include "convert.cuh"

struct conv_params {
    const int64_t IW, IH;
    const int64_t OW, OH;
    const int64_t KW, KH;
    const int64_t ST_X, ST_Y;
    const int64_t PD_X, PD_Y;
    const int64_t DL_X, DL_Y;
    const int64_t IC, OC;
    const int64_t B;
    const int64_t TOTAL;
};

struct kernel_bounds {
    int64_t y_min, y_max;
    int64_t x_min, x_max;
};

__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) {
    return (a > b) ? a : b;
}

__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) {
    return (a < b) ? a : b;
}

__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) {
    kernel_bounds bounds;
    bounds.y_min = max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
    bounds.y_max = min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
    bounds.x_min = max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
    bounds.x_max = min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
    return bounds;
}

__device__ __forceinline__ int calculate_input_coord(int64_t out_coord,
                                                     int64_t kern_coord,
                                                     int64_t stride,
                                                     int64_t dilation,
                                                     int64_t padding) {
    return out_coord * stride + kern_coord * dilation - padding;
}

struct whcn_layout {
    __device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
        return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x;
    }

    __device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) {
        return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx;
    }

    __device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
        return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x;
    }

    __device__ static void unpack_indices(int64_t global_idx,
                                          const conv_params & P,
                                          int64_t & n,
                                          int64_t & c,
                                          int64_t & out_y,
                                          int64_t & out_x) {
        out_x = global_idx % P.OW;
        out_y = (global_idx / P.OW) % P.OH;
        c     = (global_idx / (P.OW * P.OH)) % P.OC;
        n     = global_idx / (P.OW * P.OH * P.OC);
    }
};
template <typename T, typename Layout>
static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
                                              const T * __restrict__ kernel,
                                              float * __restrict__ output,
                                              const conv_params P) {
    // double-buffered shared memory: the first 16 KiB hold the weight tiles, the rest the input tiles
    __shared__ __align__(16 * 1024) char smem[24 * 1024];
    T     * smemweight = reinterpret_cast<T *>(smem);
    float * smeminput  = reinterpret_cast<float *>(smem + 16 * 1024);

    int tx = threadIdx.x;
    int bx = blockIdx.x;
    int by = blockIdx.y;

    // Warp tile
    const int lane_id   = threadIdx.x % 32;
    const int warp_id   = threadIdx.x / 32;
    const int mma_tid_x = (lane_id / 2) % 8;
    const int mma_tid_y = (lane_id / 16) * 2 + (lane_id % 2);
    // lds addr
    int weight_lds_addr = (warp_id / 2) * 32 + mma_tid_y * 4;
    int input_lds_addr  = (warp_id % 2) * 64 + mma_tid_x * 4;

    int x = bx * 128 + input_lds_addr;
    int y = by * 128 + weight_lds_addr;
    int z = blockIdx.z;

    T     weight_ldg_reg[4];
    float input_ldg_reg[4];

    // top-left input coordinates of the 4 output positions handled by this thread
    // note: the kernel assumes a dilation of 1, DL_X/DL_Y are not applied here
    int posh_ori[4];
    int posw_ori[4];
#pragma unroll
    for (int i = 0; i < 4; ++i) {
        posh_ori[i] = ((bx * 128 + tx % 32 + i * 32) / P.OW) * P.ST_Y - P.PD_Y;
        posw_ori[i] = ((bx * 128 + tx % 32 + i * 32) % P.OW) * P.ST_X - P.PD_X;
    }

    int inOffset            = z * P.IC * P.IH * P.IW;
    int weiOffset           = (by * 128 + tx / 8 * 4) * P.IC * P.KH * P.KW;
    int inChannelOffset     = P.IH * P.IW;
    int weightChannelOffset = P.KH * P.KW;
    int weightKOffset       = P.IC * P.KH * P.KW;

    // sts addr
    int weight_sts_addr = (tx % 8) * 132 + (tx / 8) * 4;
    int input_sts_addr  = (tx / 32) * 128 + (tx % 32);

    int write_flag = 1;
    T     weight_frag[2][8];
    float input_frag[2][8];
    float output_frag[8][8];
#pragma unroll
    for (int i = 0; i < 8; ++i) {
#pragma unroll
        for (int j = 0; j < 8; ++j) {
            output_frag[i][j] = 0;
        }
    }
    // ldg: load the first weight tile from global memory
#pragma unroll
    for (int i = 0; i < 4; ++i) {
        if (tx % 8 < weightKOffset && by * 128 + tx / 8 * 4 + i < P.OC) {
            weight_ldg_reg[i] = kernel[weiOffset + tx % 8 + i * weightKOffset];
        } else {
            weight_ldg_reg[i] = (T)0.f;
        }
    }
    int curC = (tx / 32) / (P.KH * P.KW);          // channel offset
    int curR = ((tx / 32) % (P.KH * P.KW)) / P.KW; // kernel r offset
    int curS = ((tx / 32) % (P.KH * P.KW)) % P.KW; // kernel s offset
#pragma unroll
    for (int i = 0; i < 4; ++i) {
        int curH = posh_ori[i] + curR; // input h
        int curW = posw_ori[i] + curS; // input w
        int inOffsetTmp = curC * inChannelOffset + curH * P.IW + curW;
        // curC < P.IC guards the case IC*KH*KW < 8, where some lanes would read past the last channel
        if (curC < P.IC && curH >= 0 && curW >= 0 && curW < P.IW && curH < P.IH) {
            input_ldg_reg[i] = input[inOffset + inOffsetTmp];
        } else {
            input_ldg_reg[i] = 0.0f;
        }
    }
    // sts: store the first tiles to shared memory (buffer 0)
    for (int i = 0; i < 4; ++i) {
        smemweight[weight_sts_addr + i] = weight_ldg_reg[i];
    }
    for (int i = 0; i < 4; ++i) {
        smeminput[input_sts_addr + i * 32] = input_ldg_reg[i];
    }

    __syncthreads();
    // lds: load the first fragments from shared memory
#pragma unroll
    for (int i = 0; i < 4; ++i) {
        weight_frag[0][i]     = smemweight[weight_lds_addr + i];
        weight_frag[0][i + 4] = smemweight[weight_lds_addr + i + 16];
    }
#pragma unroll
    for (int i = 0; i < 4; ++i) {
        input_frag[0][i]     = smeminput[input_lds_addr + i];
        input_frag[0][i + 4] = smeminput[input_lds_addr + i + 32];
    }
    // main loop over the reduction dimension crs = IC*KH*KW, 8 elements per iteration,
    // double buffered: while the current tile is consumed, the next one is prefetched
    for (int crs = 0; crs < P.KH * P.KW * P.IC; crs += 8) {
        // ldg: prefetch the next weight tile
        int weiOffsetTmp = crs + 8 + tx % 8;
#pragma unroll
        for (int i = 0; i < 4; ++i) {
            if (weiOffsetTmp < weightKOffset && by * 128 + tx / 8 * 4 + i < P.OC) {
                weight_ldg_reg[i] = kernel[weiOffset + weiOffsetTmp + i * weightKOffset];
            } else {
                weight_ldg_reg[i] = (T)0.f;
            }
        }
        curC = (crs + 8 + tx / 32) / (P.KH * P.KW);          // channel offset
        curR = ((crs + 8 + tx / 32) % (P.KH * P.KW)) / P.KW; // kernel r offset
        curS = ((crs + 8 + tx / 32) % (P.KH * P.KW)) % P.KW; // kernel s offset

#pragma unroll
        for (int i = 0; i < 4; ++i) {
            int curH = posh_ori[i] + curR; // input h
            int curW = posw_ori[i] + curS; // input w
            int inOffsetTmp = curC * inChannelOffset + curH * P.IW + curW;
            // curC < P.IC guards the tail iteration, where crs + 8 runs past IC*KH*KW
            if (curC < P.IC && curH >= 0 && curW >= 0 && curW < P.IW && curH < P.IH) {
                input_ldg_reg[i] = input[inOffset + inOffsetTmp];
            } else {
                input_ldg_reg[i] = 0.f;
            }
        }
        int load_flag = write_flag ^ 1;
#pragma unroll
        for (int subcrs = 0; subcrs < 8 - 1; ++subcrs) {
#pragma unroll
            for (int i = 0; i < 4; ++i) {
                weight_frag[(subcrs + 1) % 2][i]     = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i];
                weight_frag[(subcrs + 1) % 2][i + 4] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i + 16];
            }
#pragma unroll
            for (int i = 0; i < 4; ++i) {
                input_frag[(subcrs + 1) % 2][i]     = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i];
                input_frag[(subcrs + 1) % 2][i + 4] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i + 32];
            }

#pragma unroll
            for (int i = 0; i < 8; ++i) {
#pragma unroll
                for (int j = 0; j < 8; ++j) {
                    output_frag[i][j] += ggml_cuda_cast<float>(weight_frag[subcrs % 2][i]) * input_frag[subcrs % 2][j];
                }
            }
        }
        // sts: store the prefetched tiles into the write buffer
        for (int i = 0; i < 4; ++i) {
            smemweight[write_flag * 132 * 8 + weight_sts_addr + i] = weight_ldg_reg[i];
        }
        for (int i = 0; i < 4; ++i) {
            smeminput[write_flag * 128 * 8 + input_sts_addr + i * 32] = input_ldg_reg[i];
        }
        __syncthreads();
        write_flag ^= 1;
#pragma unroll
        for (int i = 0; i < 4; ++i) {
            weight_frag[0][i]     = smemweight[(load_flag ^ 1) * 132 * 8 + weight_lds_addr + i];
            weight_frag[0][i + 4] = smemweight[(load_flag ^ 1) * 132 * 8 + weight_lds_addr + i + 16];
        }
#pragma unroll
        for (int i = 0; i < 4; ++i) {
            input_frag[0][i]     = smeminput[(load_flag ^ 1) * 128 * 8 + input_lds_addr + i];
            input_frag[0][i + 4] = smeminput[(load_flag ^ 1) * 128 * 8 + input_lds_addr + i + 32];
        }
#pragma unroll
        for (int i = 0; i < 8; ++i) {
#pragma unroll
            for (int j = 0; j < 8; ++j) {
                output_frag[i][j] += ggml_cuda_cast<float>(weight_frag[1][i]) * input_frag[1][j];
            }
        }
    }
    // reuse the shared memory for the output tile
    float * smemoutput = reinterpret_cast<float *>(smem);
    // float *smembias = reinterpret_cast<float *>(smem + 16 * 1024);

    // bias ldg/sts
    // if (tx < 128)
    // {
    //     smembias[tx] = param.bias[by * 128 + tx];
    // }

    uint32_t output_sts_addr = warp_id * 512 + mma_tid_y * 4 * 8 * 4 + mma_tid_x * 4;
    uint32_t output_lds_addr = warp_id * 512 + lane_id;
    // uint32_t bias_lds_addr = warp_id / 2 * 32;

    uint32_t m_idx = blockIdx.y * 128 + warp_id / 2 * 32;
    uint32_t n_idx = blockIdx.x * 128 + warp_id % 2 * 64 + lane_id;

#pragma unroll
    for (int i = 0; i < 2; ++i) {
#pragma unroll
        for (int j = 0; j < 2; ++j) {
            __syncthreads();

#pragma unroll
            for (int subi = 0; subi < 4; ++subi) {
#pragma unroll
                for (int subj = 0; subj < 4; ++subj) {
                    // output sts
                    smemoutput[output_sts_addr + subi * 8 * 4 + subj] = output_frag[i * 4 + subi][j * 4 + subj];
                }
            }
            __syncthreads();

#pragma unroll
            for (int subk = 0; subk < 16; ++subk) {
                int outOffset = z * P.OC * P.OH * P.OW + (m_idx + i * 16 + subk) * P.OH * P.OW + n_idx + j * 32;
                if ((m_idx + i * 16 + subk) < P.OC && (n_idx + j * 32) < P.OH * P.OW) {
                    // output[outOffset] = smemoutput[output_lds_addr + subk * 32] + smembias[bias_lds_addr + i * 16 + subk];
                    output[outOffset] = smemoutput[output_lds_addr + subk * 32];
                }
            }
        }
    }
}
template <typename T>
static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
    // each block of CUDA_CONV2D_IMPLICIT_BLOCK_SIZE (256) threads computes a 128x128 output tile:
    // grid.x tiles the OH*OW dimension, grid.y tiles the OC dimension and grid.z walks the batch
    const dim3 grid((unsigned) ((P.OH * P.OW + 127) / 128), (unsigned) ((P.OC + 127) / 128), (unsigned) P.B);
    conv2d_implicit_kernel<T, whcn_layout><<<grid, CUDA_CONV2D_IMPLICIT_BLOCK_SIZE, 0, st>>>(X_D, K_D, Y_D, P);
}

static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
    conv2d_implicit_cuda<half>(X_D, K_D, Y_D, P, st);
}

static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
    conv2d_implicit_cuda<float>(X_D, K_D, Y_D, P, st);
}
void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * kernel = dst->src[0];
    const ggml_tensor * input  = dst->src[1];
    float * K_D = (float *) kernel->data;
    const float * X_D = (const float *) input->data;
    float * Y_D = (float *) dst->data;

    GGML_ASSERT(ggml_is_contiguous(kernel));
    GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32);

    // same number of input channels
    GGML_ASSERT(input->ne[2] == kernel->ne[2]);

    cudaStream_t st = ctx.stream();

    const int32_t * p = (const int32_t *) dst->op_params;
    const int ST_X = p[0]; // stride_x
    const int ST_Y = p[1]; // stride_y
    const int PD_X = p[2]; // padding_x
    const int PD_Y = p[3]; // padding_y
    const int DL_X = p[4]; // dilation_x
    const int DL_Y = p[5]; // dilation_y

    // No cwhn
    GGML_ASSERT(p[6] == false);

    const int IW = input->ne[0];  // input_w
    const int IH = input->ne[1];  // input_h
    const int OW = dst->ne[0];    // output_w
    const int OH = dst->ne[1];    // output_h
    const int KW = kernel->ne[0]; // kernel_w
    const int KH = kernel->ne[1]; // kernel_h
    const int IC = input->ne[2];  // input_channels
    const int OC = kernel->ne[3]; // output_channels
    const int B  = input->ne[3];  // n_batches

    const int64_t total = B * OC * OH * OW;
    conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total };

    if (kernel->type == GGML_TYPE_F16) {
        conv2d_implicit_cuda_f16(X_D, (half *) K_D, Y_D, params, st);
    } else {
        conv2d_implicit_cuda_f32(X_D, K_D, Y_D, params, st);
    }
}
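For readers not familiar with the implicit GEMM formulation: per batch image, the kernel above computes a GEMM with M = OC, N = OH*OW and K = IC*KH*KW in which the im2col matrix is never materialised; the (c, r, s) coordinates are decoded from the K index on the fly. A simplified single-threaded reference of the same semantics (dilation 1, which is what the kernel implements) is sketched below for illustration, using the conv_params field names; it is not part of the commit:

    // per batch image: output[k][oh][ow] = sum over c,r,s of
    //     kernel[k][c][r][s] * input[c][oh*ST_Y - PD_Y + r][ow*ST_X - PD_X + s]
    // out-of-bounds input positions contribute zero
    for (int64_t k = 0; k < OC; ++k) {                         // GEMM M dimension
        for (int64_t ohw = 0; ohw < OH * OW; ++ohw) {          // GEMM N dimension
            float acc = 0.0f;
            for (int64_t crs = 0; crs < IC * KH * KW; ++crs) { // GEMM K dimension
                const int64_t c  = crs / (KH * KW);
                const int64_t r  = (crs % (KH * KW)) / KW;
                const int64_t s  = crs % KW;
                const int64_t ih = (ohw / OW) * ST_Y - PD_Y + r;
                const int64_t iw = (ohw % OW) * ST_X - PD_X + s;
                if (ih >= 0 && ih < IH && iw >= 0 && iw < IW) {
                    acc += kernel[(k * IC + c) * KH * KW + r * KW + s] * input[(c * IH + ih) * IW + iw];
                }
            }
            output[k * OH * OW + ohw] = acc;
        }
    }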
@ -0,0 +1,5 @@
#pragma once
#include "common.cuh"

#define CUDA_CONV2D_IMPLICIT_BLOCK_SIZE 256
void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@ -4482,6 +4482,45 @@ struct ggml_tensor * ggml_conv_2d_direct(
    return result;
}

// ggml_conv_2d_implicitgemm

struct ggml_tensor * ggml_conv_2d_implicitgemm(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,  // convolution kernel [KW, KH, IC, OC]
        struct ggml_tensor  * b,  // input data [W, H, C, N]
        int s0,   // stride dimension 0
        int s1,   // stride dimension 1
        int p0,   // padding dimension 0
        int p1,   // padding dimension 1
        int d0,   // dilation dimension 0
        int d1) { // dilation dimension 1

    GGML_ASSERT(a->ne[2] == b->ne[2]);
    //GGML_ASSERT(a->type == b->type);

    int64_t ne[4];
    ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
    ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
    ne[2] = a->ne[3];
    ne[3] = b->ne[3];

    struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne);

    ggml_set_op_params_i32(result, 0, s0);
    ggml_set_op_params_i32(result, 1, s1);
    ggml_set_op_params_i32(result, 2, p0);
    ggml_set_op_params_i32(result, 3, p1);
    ggml_set_op_params_i32(result, 4, d0);
    ggml_set_op_params_i32(result, 5, d1);

    result->op     = GGML_OP_CONV_2D_IMPLICIT;
    result->src[0] = a;
    result->src[1] = b;

    return result;
}

// ggml_conv_3d

struct ggml_tensor * ggml_conv_3d(