fix passing param as reference

bssrdf 2025-09-03 12:45:19 -04:00
parent 4d772873b9
commit 3877608dc0
3 changed files with 118 additions and 8 deletions
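The gist of the fix: arguments to a __global__ kernel are copied into device-side parameter space at launch, so declaring the parameter as const param_t & would hand the device a reference to a host-side object it cannot legally dereference; passing the small POD struct by value is the safe pattern. A minimal sketch of that rule with hypothetical names (conv_params_sketch stands in for param_t; this is not the ggml kernel):

// Sketch only: hypothetical names, not the ggml kernel. Kernel arguments are
// copied into device parameter space at launch, so a small POD struct can be
// passed by value; a reference parameter would point back at host memory.
#include <cstdio>
#include <cuda_runtime.h>

struct conv_params_sketch { int n, c, h, w; };       // stand-in for param_t

__global__ void show_params(const conv_params_sketch p) {   // by value: safe
    if (threadIdx.x == 0 && blockIdx.x == 0) {
        printf("n=%d c=%d h=%d w=%d\n", p.n, p.c, p.h, p.w);
    }
}

int main() {
    conv_params_sketch p = { 1, 16, 64, 64 };        // lives on the host stack
    show_params<<<1, 32>>>(p);                       // struct is copied at launch
    cudaDeviceSynchronize();
    return 0;
}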

View File

@@ -25,9 +25,9 @@ template <typename T>
static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
const T * __restrict__ kernel,
float * __restrict__ output,
const param_t &param) {
const param_t param) {
extern __shared__ __align__(16 * 1024) char smem[];
T *smemweight = reinterpret_cast<T *>(smem);
float *smeminput = reinterpret_cast<float *>(smem + 16 * 1024);
@@ -35,6 +35,12 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
int bx = blockIdx.x;
int by = blockIdx.y;
// if(tx == 0 && bx == 0 && by == 0 && blockIdx.z == 0){
// printf("param.n=%d, param.c=%d, param.h=%d, param.w=%d, param.k=%d, param.r=%d, param.s=%d, param.u=%d, param.v=%d, param.p=%d, param.q=%d, param.d_h=%d, param.d_w=%d, param.Oh=%d, param.Ow=%d\n",param.n,param.c,param.h,param.w,param.k,param.r,param.s,param.u,param.v,param.p,param.q,param.d_h,param.d_w,param.Oh,param.Ow);
// // printf("param.n=%d\n",param.n);
// }
// __syncthreads();
// Warp tile
const int lane_id = threadIdx.x % 32;
const int warp_id = threadIdx.x / 32;
@@ -85,6 +91,10 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
}
}
// ldg
// if(tx == 0 && bx == 0 && by == 0 && blockIdx.z == 0){
// printf("param.n=%d, param.c=%d, param.h=%d, param.w=%d, param.k=%d, param.r=%d, param.s=%d, param.u=%d, param.v=%d, param.p=%d, param.q=%d, param.d_h=%d, param.d_w=%d, param.Oh=%d, param.Ow=%d\n",param.n,param.c,param.h,param.w,param.k,param.r,param.s,param.u,param.v,param.p,param.q,param.d_h,param.d_w,param.Oh,param.Ow);
// }
// __syncthreads();
#pragma unroll
for (int i = 0; i < 4; ++i)
{
@@ -282,11 +292,10 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
}
}
}
}
template <typename T>
static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t &P, cudaStream_t st) {
static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) {
// const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE;
int blockx = ((P.Oh * P.Ow + 127) / 128); // blockx number
int blocky = (P.k + 127) / 128; // blocky number
@@ -300,11 +309,11 @@ static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D,
conv2d_implicit_kernel<T><<<grid, thblock, smem_size, st>>>(X_D, K_D, Y_D, P);
}
static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const param_t &P, cudaStream_t st) {
static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const param_t P, cudaStream_t st) {
conv2d_implicit_cuda<half>(X_D, K_D, Y_D, P, st);
}
static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const param_t &P, cudaStream_t st) {
static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const param_t P, cudaStream_t st) {
conv2d_implicit_cuda<float>(X_D, K_D, Y_D, P, st);
}
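For context on the launch configuration in this part of the file: the implicit GEMM is tiled into 128x128 CTA tiles over M = Oh*Ow output positions and N = k output channels, which is what the two ceil-divisions above compute. A sketch of the grid calculation; the batch-sized z dimension and the trimmed-down parameter struct are assumptions here, since those lines are not part of the shown hunks:

// Sketch only: launch_params is a hypothetical subset of param_t; blockx and
// blocky mirror the hunk above, blockz = batch is an assumption.
#include <cuda_runtime.h>

struct launch_params { int Oh, Ow, k, n; };

static dim3 make_conv_grid(const launch_params & P) {
    const int blockx = (P.Oh * P.Ow + 127) / 128;    // ceil(M / 128), M = Oh * Ow
    const int blocky = (P.k + 127) / 128;            // ceil(N / 128), N = output channels
    const int blockz = P.n;                          // assumed: one z slice per batch image
    return dim3(blockx, blocky, blockz);
}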
@@ -343,9 +352,9 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor *
const int IC = input->ne[2]; // input_channels
const int OC = kernel->ne[3]; // output_channels
const int B = input->ne[3]; // n_batches
const int64_t total = B * OC * OH * OW;
// param_t params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total };
param_t params = { B, IC, IH, IW, OC, KH, KW, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, OH, OW };
if (kernel->type == GGML_TYPE_F16) {

View File

@@ -975,6 +975,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"IM2COL",
"IM2COL_BACK",
"CONV_2D",
"CONV_2D_IMPLICIT",
"CONV_3D",
"CONV_2D_DW",
"CONV_TRANSPOSE_2D",
@@ -1078,6 +1079,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"im2col(x)",
"im2col_back(x)",
"conv_2d(x)",
"conv_2d_implicit(x)",
"conv_3d(x)",
"conv_2d_dw(x)",
"conv_transpose_2d(x)",

View File

@@ -4116,6 +4116,94 @@ struct test_conv_2d : public test_case {
}
};
// CONV_2D_IMPLICIT
struct test_conv_2d_implicit : public test_case {
const std::array<int64_t, 4> ne_input;
const std::array<int64_t, 4> ne_kernel;
const ggml_type type_kernel;
const int stride0;
const int stride1;
const int padding0;
const int padding1;
const int dilation0;
const int dilation1;
// Whether the inputs are contiguous in the channel dim or the width dim
const bool cwhn;
std::string vars() override {
return VARS_TO_STR10(ne_input, ne_kernel, type_kernel, stride0, stride1, padding0, padding1, dilation0, dilation1, cwhn);
}
double max_nmse_err() override {
return 5e-4;
}
uint64_t op_flops(ggml_tensor * t) override {
GGML_UNUSED(t);
// Just counting matmul costs:
// KxCRS @ CRSxNPQ = KxNPQ --> KxNPQx(CRS+CRS-1) flops
// Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d)
auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
};
int64_t W = ne_input[0];
int64_t H = ne_input[1];
int64_t KW = ne_kernel[0];
int64_t KH = ne_kernel[1];
int64_t Cin = ne_kernel[2];
int64_t Cout = ne_kernel[3];
int64_t N = ne_input[3];
int64_t OH = calc_conv_output_size(H, KH, stride1, padding1, dilation1);
int64_t OW = calc_conv_output_size(W, KW, stride0, padding0, dilation0);
int64_t K = Cout;
int64_t CRS = Cin * KH * KW;
int64_t NPQ = N * OH * OW;
return K * NPQ * (2 * CRS - 1);
}
test_conv_2d_implicit(std::array<int64_t, 4> ne_input = { 64, 64, 16, 1 },
std::array<int64_t, 4> ne_kernel = { 3, 3, 1, 16 }, ggml_type type_kernel = GGML_TYPE_F32, int stride0 = 1,
int stride1 = 1, int padding0 = 0, int padding1 = 0, int dilation0 = 1, int dilation1 = 1, bool cwhn = false) :
ne_input(ne_input),
ne_kernel(ne_kernel),
type_kernel(type_kernel),
stride0(stride0),
stride1(stride1),
padding0(padding0),
padding1(padding1),
dilation0(dilation0),
dilation1(dilation1),
cwhn(cwhn) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
ggml_set_name(input, "input");
ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());
ggml_set_name(kernel, "kernel");
if (cwhn) {
// change memory layout to channel-most-contiguous (CWHN),
// then permute it back so NE matches the original input
input = ggml_cont(ctx, ggml_permute(ctx, input, 1, 2, 0, 3));
input = ggml_permute(ctx, input, 2, 0, 1, 3);
kernel = ggml_cont(ctx, ggml_permute(ctx, kernel, 2, 3, 1, 0));
kernel = ggml_permute(ctx, kernel, 3, 2, 0, 1);
}
ggml_tensor * out =
ggml_conv_2d_implicitgemm(ctx, kernel, input, stride0, stride1, padding0, padding1, dilation0, dilation1);
ggml_set_name(out, "out");
return out;
}
};
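The op_flops model above counts the convolution as the implicit GEMM KxCRS times CRSxNPQ, i.e. one length-CRS dot product (CRS multiplies plus CRS-1 additions) per output element. A quick self-checking example on a hypothetical shape, not one of the registered perf cases:

// Hypothetical shape: N=1, Cin=64, H=W=56, Cout=64, 3x3 kernel, stride 1, padding 1, dilation 1.
constexpr long long OH  = (56 + 2 * 1 - 1 * (3 - 1) - 1) / 1 + 1;   // 56, same formula as above
constexpr long long CRS = 64 * 3 * 3;                               // 576  = Cin * KH * KW
constexpr long long NPQ = 1 * OH * OH;                              // 3136 = N * OH * OW
constexpr long long K   = 64;                                       // Cout
static_assert(K * NPQ * (2 * CRS - 1) == 231010304LL, "about 0.23 GFLOP for this conv");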
// GGML_OP_CONV_2D_DW
struct test_conv_2d_dw : public test_case {
const std::array<int64_t, 4> ne_input;
@@ -6454,6 +6542,17 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
}
}
for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
for (auto act_case : cases) {
// Implicit GEMM CONV_2D
test_cases.emplace_back(new test_conv_2d_implicit(
{ act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
{ act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] },
kernel_type, 1, 1, 0, 0, 1, 1, false));
}
}
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1}));
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));