vulkan: support im2col_3d (#15795)
This commit is contained in:
parent
d36e61c580
commit
3976dfbe00
|
|
@ -554,6 +554,7 @@ struct vk_device_struct {
|
|||
vk_pipeline pipeline_argmax_f32;
|
||||
vk_pipeline pipeline_count_equal_i32;
|
||||
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
|
||||
vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16;
|
||||
vk_pipeline pipeline_timestep_embedding_f32;
|
||||
vk_pipeline pipeline_conv_transpose_1d_f32;
|
||||
vk_pipeline pipeline_pool2d_f32;
|
||||
|
|
@ -982,6 +983,37 @@ struct vk_op_im2col_push_constants {
|
|||
int32_t d0; int32_t d1;
|
||||
};
|
||||
|
||||
struct vk_op_im2col_3d_push_constants {
|
||||
uint32_t nb10;
|
||||
uint32_t nb11;
|
||||
uint32_t nb12;
|
||||
uint32_t nb13;
|
||||
uint32_t s0;
|
||||
uint32_t s1;
|
||||
uint32_t s2;
|
||||
uint32_t p0;
|
||||
uint32_t p1;
|
||||
uint32_t p2;
|
||||
uint32_t d0;
|
||||
uint32_t d1;
|
||||
uint32_t d2;
|
||||
uint32_t IW;
|
||||
uint32_t IH;
|
||||
uint32_t ID;
|
||||
uint32_t IC;
|
||||
uint32_t KW;
|
||||
uint32_t OH;
|
||||
uint32_t KD_KH_KW;
|
||||
uint32_t KH_KW;
|
||||
uint32_t IC_KD_KH_KW;
|
||||
uint32_t N_OD_OH;
|
||||
uint32_t OD_OH;
|
||||
uint32_t OD_OH_OW_IC_KD_KH_KW;
|
||||
uint32_t OH_OW_IC_KD_KH_KW;
|
||||
uint32_t OW_IC_KD_KH_KW;
|
||||
uint32_t misalign_offsets;
|
||||
};
|
||||
|
||||
struct vk_op_timestep_embedding_push_constants {
|
||||
uint32_t nb1;
|
||||
uint32_t dim;
|
||||
|
|
@ -3380,10 +3412,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32_len, im2col_3d_f32_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
|
||||
if (device->float_controls_rte_fp16) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte_len, im2col_3d_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
|
||||
} else {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_len, im2col_3d_f32_f16_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
|
||||
}
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
|
||||
|
|
@ -7717,6 +7752,14 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|||
return ctx->device->pipeline_im2col_f32_f16;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_IM2COL_3D:
|
||||
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_im2col_3d_f32;
|
||||
}
|
||||
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
|
||||
return ctx->device->pipeline_im2col_3d_f32_f16;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_timestep_embedding_f32;
|
||||
|
|
@ -7832,6 +7875,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
|
|||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_IM2COL_3D:
|
||||
case GGML_OP_SET_ROWS:
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
|
|
@ -7890,6 +7934,16 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
|
|||
GGML_UNUSED(src2);
|
||||
}
|
||||
|
||||
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
|
||||
const uint32_t a_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
|
||||
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
|
||||
|
||||
p.misalign_offsets = (a_offset << 16) | d_offset;
|
||||
|
||||
GGML_UNUSED(src0);
|
||||
GGML_UNUSED(src2);
|
||||
}
|
||||
|
||||
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
|
||||
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
|
||||
const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
|
||||
|
|
@ -8130,6 +8184,26 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|||
|
||||
elements = { OW * KW * KH, OH, batch * IC };
|
||||
} break;
|
||||
case GGML_OP_IM2COL_3D:
|
||||
{
|
||||
const uint32_t IC = ((const uint32_t *)(dst->op_params))[9];
|
||||
|
||||
const uint32_t N = ne13 / IC;
|
||||
|
||||
const uint32_t KD = ne02;
|
||||
const uint32_t KH = ne01;
|
||||
const uint32_t KW = ne00;
|
||||
|
||||
const uint32_t OD = ned3 / N;
|
||||
const uint32_t OH = ned2;
|
||||
const uint32_t OW = ned1;
|
||||
|
||||
const uint32_t IC_KD_KH_KW = IC*KD*KH*KW;
|
||||
const uint32_t N_OD_OH = N*OD*OH;
|
||||
|
||||
elements = { IC_KD_KH_KW, OW, N_OD_OH };
|
||||
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
|
||||
} break;
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
{
|
||||
const uint32_t dim = dst->op_params[0];
|
||||
|
|
@ -8286,7 +8360,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|||
}
|
||||
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
|
||||
} else if (op == GGML_OP_IM2COL) {
|
||||
} else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) {
|
||||
// im2col uses only src1 and dst buffers
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
|
||||
} else if (op == GGML_OP_COUNT_EQUAL) {
|
||||
|
|
@ -9147,6 +9221,66 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|||
}, dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
|
||||
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
|
||||
const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
|
||||
const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
|
||||
const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
|
||||
const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
|
||||
const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
|
||||
const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
|
||||
const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
|
||||
const int32_t IC = ((const int32_t *)(dst->op_params))[9];
|
||||
|
||||
const int64_t N = ne13 / IC;
|
||||
const int64_t ID = ne12;
|
||||
const int64_t IH = ne11;
|
||||
const int64_t IW = ne10;
|
||||
|
||||
const int64_t KD = ne02;
|
||||
const int64_t KH = ne01;
|
||||
const int64_t KW = ne00;
|
||||
|
||||
const int64_t OD = ne3 / N;
|
||||
const int64_t OH = ne2;
|
||||
const int64_t OW = ne1;
|
||||
|
||||
vk_op_im2col_3d_push_constants pc {};
|
||||
|
||||
pc.nb10 = nb10 / ggml_type_size(src1->type);
|
||||
pc.nb11 = nb11 / ggml_type_size(src1->type);
|
||||
pc.nb12 = nb12 / ggml_type_size(src1->type);
|
||||
pc.nb13 = nb13 / ggml_type_size(src1->type);
|
||||
pc.s0 = s0;
|
||||
pc.s1 = s1;
|
||||
pc.s2 = s2;
|
||||
pc.p0 = p0;
|
||||
pc.p1 = p1;
|
||||
pc.p2 = p2;
|
||||
pc.d0 = d0;
|
||||
pc.d1 = d1;
|
||||
pc.d2 = d2;
|
||||
pc.IW = IW;
|
||||
pc.IH = IH;
|
||||
pc.ID = ID;
|
||||
pc.IC = IC;
|
||||
pc.KW = KW;
|
||||
pc.OH = OH;
|
||||
pc.KD_KH_KW = KD*KH*KW;
|
||||
pc.KH_KW = KH*KW;
|
||||
pc.IC_KD_KH_KW = IC*KD*KH*KW;
|
||||
pc.N_OD_OH = N*OD*OH;
|
||||
pc.OD_OH = OD*OH;
|
||||
pc.OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW;
|
||||
pc.OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW;
|
||||
pc.OW_IC_KD_KH_KW = OW*IC*KD*KH*KW;
|
||||
|
||||
ggml_vk_op_f32<vk_op_im2col_3d_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL_3D, std::move(pc), dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
const uint32_t dim = dst->op_params[0];
|
||||
const uint32_t max_period = dst->op_params[1];
|
||||
|
|
@ -10352,6 +10486,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|||
case GGML_OP_ARGMAX:
|
||||
case GGML_OP_COUNT_EQUAL:
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_IM2COL_3D:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
case GGML_OP_POOL_2D:
|
||||
|
|
@ -10422,6 +10557,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|||
case GGML_OP_ARGMAX:
|
||||
case GGML_OP_COUNT_EQUAL:
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_IM2COL_3D:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
case GGML_OP_POOL_2D:
|
||||
|
|
@ -10717,6 +10853,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|||
case GGML_OP_IM2COL:
|
||||
ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun);
|
||||
|
||||
break;
|
||||
case GGML_OP_IM2COL_3D:
|
||||
ggml_vk_im2col_3d(ctx, compute_ctx, src0, src1, node, dryrun);
|
||||
|
||||
break;
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun);
|
||||
|
|
@ -10868,6 +11008,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
|
|||
case GGML_OP_ARGMAX:
|
||||
case GGML_OP_COUNT_EQUAL:
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_IM2COL_3D:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
case GGML_OP_POOL_2D:
|
||||
|
|
@ -12150,6 +12291,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||
case GGML_OP_ARGMAX:
|
||||
case GGML_OP_COUNT_EQUAL:
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_IM2COL_3D:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
case GGML_OP_POOL_2D:
|
||||
|
|
@ -12725,6 +12867,19 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
|||
|
||||
const bool is_2D = tensor->op_params[6] == 1;
|
||||
tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type);
|
||||
} else if (tensor->op == GGML_OP_IM2COL_3D) {
|
||||
const int32_t s0 = tensor->op_params[0];
|
||||
const int32_t s1 = tensor->op_params[1];
|
||||
const int32_t s1 = tensor->op_params[2];
|
||||
const int32_t p0 = tensor->op_params[3];
|
||||
const int32_t p1 = tensor->op_params[4];
|
||||
const int32_t p1 = tensor->op_params[5];
|
||||
const int32_t d0 = tensor->op_params[6];
|
||||
const int32_t d1 = tensor->op_params[7];
|
||||
const int32_t d1 = tensor->op_params[8];
|
||||
const int32_t IC = tensor->op_params[9];
|
||||
|
||||
tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, tensor->type);
|
||||
} else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) {
|
||||
const int32_t dim = tensor->op_params[0];
|
||||
const int32_t max_period = tensor->op_params[1];
|
||||
|
|
|
|||
|
|
@ -0,0 +1,112 @@
|
|||
#version 450
|
||||
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||
|
||||
#include "rte.comp"
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
uint32_t nb10;
|
||||
uint32_t nb11;
|
||||
uint32_t nb12;
|
||||
uint32_t nb13;
|
||||
uint32_t s0;
|
||||
uint32_t s1;
|
||||
uint32_t s2;
|
||||
uint32_t p0;
|
||||
uint32_t p1;
|
||||
uint32_t p2;
|
||||
uint32_t d0;
|
||||
uint32_t d1;
|
||||
uint32_t d2;
|
||||
uint32_t IW;
|
||||
uint32_t IH;
|
||||
uint32_t ID;
|
||||
uint32_t IC;
|
||||
uint32_t KW;
|
||||
uint32_t OH;
|
||||
uint32_t KD_KH_KW;
|
||||
uint32_t KH_KW;
|
||||
uint32_t IC_KD_KH_KW;
|
||||
uint32_t N_OD_OH;
|
||||
uint32_t OD_OH;
|
||||
uint32_t OD_OH_OW_IC_KD_KH_KW;
|
||||
uint32_t OH_OW_IC_KD_KH_KW;
|
||||
uint32_t OW_IC_KD_KH_KW;
|
||||
uint32_t misalign_offsets;
|
||||
} p;
|
||||
|
||||
#include "types.comp"
|
||||
|
||||
uint get_aoffset() { return p.misalign_offsets >> 16; }
|
||||
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
|
||||
|
||||
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
void main() {
|
||||
const uint32_t i = gl_GlobalInvocationID.x;
|
||||
|
||||
uint32_t nb10 = p.nb10;
|
||||
uint32_t nb11 = p.nb11;
|
||||
uint32_t nb12 = p.nb12;
|
||||
uint32_t nb13 = p.nb13;
|
||||
uint32_t s0 = p.s0;
|
||||
uint32_t s1 = p.s1;
|
||||
uint32_t s2 = p.s2;
|
||||
uint32_t p0 = p.p0;
|
||||
uint32_t p1 = p.p1;
|
||||
uint32_t p2 = p.p2;
|
||||
uint32_t d0 = p.d0;
|
||||
uint32_t d1 = p.d1;
|
||||
uint32_t d2 = p.d2;
|
||||
uint32_t IW = p.IW;
|
||||
uint32_t IH = p.IH;
|
||||
uint32_t ID = p.ID;
|
||||
uint32_t IC = p.IC;
|
||||
uint32_t KW = p.KW;
|
||||
uint32_t OH = p.OH;
|
||||
uint32_t KD_KH_KW = p.KD_KH_KW;
|
||||
uint32_t KH_KW = p.KH_KW;
|
||||
uint32_t IC_KD_KH_KW = p.IC_KD_KH_KW;
|
||||
uint32_t N_OD_OH = p.N_OD_OH;
|
||||
uint32_t OD_OH = p.OD_OH;
|
||||
uint32_t OD_OH_OW_IC_KD_KH_KW = p.OD_OH_OW_IC_KD_KH_KW;
|
||||
uint32_t OH_OW_IC_KD_KH_KW = p.OH_OW_IC_KD_KH_KW;
|
||||
uint32_t OW_IC_KD_KH_KW = p.OW_IC_KD_KH_KW;
|
||||
|
||||
if (i >= IC_KD_KH_KW) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t iic = i / KD_KH_KW;
|
||||
const uint32_t ikd = (i - iic * KD_KH_KW) / KH_KW;
|
||||
const uint32_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW;
|
||||
const uint32_t ikw = i % KW;
|
||||
|
||||
const uint32_t iow = gl_GlobalInvocationID.y;
|
||||
for (uint32_t iz = gl_GlobalInvocationID.z; iz < N_OD_OH; iz += gl_NumWorkGroups.z) {
|
||||
const uint32_t in_ = iz / OD_OH;
|
||||
const uint32_t iod = (iz - in_*OD_OH) / OH;
|
||||
const uint32_t ioh = iz % OH;
|
||||
|
||||
const uint32_t iiw = iow * s0 + ikw * d0 - p0;
|
||||
const uint32_t iih = ioh * s1 + ikh * d1 - p1;
|
||||
const uint32_t iid = iod * s2 + ikd * d2 - p2;
|
||||
|
||||
const uint32_t offset_dst = in_*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
|
||||
|
||||
if (iih >= IH || iiw >= IW || iid >= ID) {
|
||||
data_d[offset_dst + get_doffset()] = D_TYPE(0.0f);
|
||||
} else {
|
||||
const uint32_t offset_src = (in_*IC + iic)*nb13 + iid*nb12 + iih*nb11 + iiw*nb10;
|
||||
data_d[offset_dst + get_doffset()] = D_TYPE(data_a[offset_src + get_aoffset()]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -713,6 +713,10 @@ void process_shaders() {
|
|||
string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
|
||||
string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));
|
||||
|
||||
string_to_spv("im2col_3d_f32", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("im2col_3d_f32_f16", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
|
||||
string_to_spv("im2col_3d_f32_f16_rte", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));
|
||||
|
||||
string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
|
|
|||
|
|
@ -300,6 +300,7 @@ static std::string var_to_str(ggml_scale_mode mode) {
|
|||
#define VARS_TO_STR13(a, b, c, d, e, f, g, h, i, j, k, l, m) VAR_TO_STR(a) + "," + VARS_TO_STR12(b, c, d, e, f, g, h, i, j, k, l, m)
|
||||
#define VARS_TO_STR14(a, b, c, d, e, f, g, h, i, j, k, l, m, n) VAR_TO_STR(a) + "," + VARS_TO_STR13(b, c, d, e, f, g, h, i, j, k, l, m, n)
|
||||
#define VARS_TO_STR15(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o) VAR_TO_STR(a) + "," + VARS_TO_STR14(b, c, d, e, f, g, h, i, j, k, l, m, n, o)
|
||||
#define VARS_TO_STR16(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p) VAR_TO_STR(a) + "," + VARS_TO_STR15(b, c, d, e, f, g, h, i, j, k, l, m, n, o, p)
|
||||
|
||||
#ifdef GGML_USE_SYCL
|
||||
static bool inline _isinf(float f) {
|
||||
|
|
@ -4047,9 +4048,10 @@ struct test_im2col_3d : public test_case {
|
|||
const int d2;
|
||||
|
||||
const int64_t IC;
|
||||
const bool v;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR15(type_input, type_kernel, dst_type, ne_input, ne_kernel, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2);
|
||||
return VARS_TO_STR16(type_input, type_kernel, dst_type, ne_input, ne_kernel, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, v);
|
||||
}
|
||||
|
||||
test_im2col_3d(ggml_type type_input = GGML_TYPE_F32, ggml_type type_kernel = GGML_TYPE_F16, ggml_type dst_type = GGML_TYPE_F32,
|
||||
|
|
@ -4058,14 +4060,20 @@ struct test_im2col_3d : public test_case {
|
|||
int64_t IC = 3,
|
||||
int s0 = 1, int s1 = 1, int s2 = 1,
|
||||
int p0 = 1, int p1 = 1, int p2 = 1,
|
||||
int d0 = 1, int d1 = 1, int d2 = 1)
|
||||
: type_input(type_input), type_kernel(type_kernel), dst_type(dst_type), ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), s1(s1), s2(s2), p0(p0), p1(p1), p2(p2), d0(d0), d1(d1), d2(d2), IC(IC) {}
|
||||
int d0 = 1, int d1 = 1, int d2 = 1,
|
||||
bool v = false)
|
||||
: type_input(type_input), type_kernel(type_kernel), dst_type(dst_type), ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), s1(s1), s2(s2), p0(p0), p1(p1), p2(p2), d0(d0), d1(d1), d2(d2), IC(IC), v(v) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
|
||||
ggml_set_param(input);
|
||||
ggml_set_name(input, "input");
|
||||
|
||||
if (v) {
|
||||
input = ggml_view_4d(ctx, input, ne_input[0] - 2, ne_input[1] - 2, ne_input[2] - 2, ne_input[3] - 2, input->nb[1], input->nb[2], input->nb[3], 0);
|
||||
ggml_set_name(input, "view_of_input");
|
||||
}
|
||||
|
||||
ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());
|
||||
ggml_set_name(kernel, "kernel");
|
||||
|
||||
|
|
@ -5729,9 +5737,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
for (int d0 : {1, 3}) {
|
||||
for (int d1 : {1, 3}) {
|
||||
for (int d2 : {1, 3}) {
|
||||
for (int IC : {1, 3}) {
|
||||
for (bool v : {false, true}) {
|
||||
test_cases.emplace_back(new test_im2col_3d(
|
||||
GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 20, 10, 3}, {3, 3, 3, 3},
|
||||
3, s0, s1, s2, p0, p1, p2, d0, d1, d2));
|
||||
IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, v));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue