cuda/vulkan : bicubic interpolation (#17022)
* vulkan : implement upscale with bicubic interpolation * cuda : implement upscale with bicubic interpolation * tests : add ggml_interpolate with GGML_SCALE_MODE_BICUBIC to backend tests * adapt OpenCL backend to not support the OP in that case so tests don't fail * print scale mode & flags in test-backend-ops
This commit is contained in:
parent
15274c0c50
commit
1032256ec9
|
|
@ -81,6 +81,70 @@ static __global__ void upscale_f32_bilinear(const float * x, float * dst,
|
|||
dst[index] = result;
|
||||
}
|
||||
|
||||
namespace bicubic_interpolation {
|
||||
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
|
||||
__device__ const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
|
||||
|
||||
static __device__ float weight1(float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };
|
||||
static __device__ float weight2(float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
|
||||
|
||||
static __device__ float bicubic(float p0, float p1, float p2, float p3, float x) {
|
||||
const float w0 = weight2(x + 1);
|
||||
const float w1 = weight1(x + 0);
|
||||
const float w2 = weight1(1 - x);
|
||||
const float w3 = weight2(2 - x);
|
||||
return p0 * w0 + p1 * w1 + p2 * w2 + p3 * w3;
|
||||
};
|
||||
} // namespace bicubic_interpolation
|
||||
|
||||
static __global__ void upscale_f32_bicubic(const float * x, float * dst,
|
||||
const int nb00, const int nb01, const int nb02, const int nb03,
|
||||
const int ne00_src, const int ne01_src,
|
||||
const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
|
||||
const float sf0, const float sf1, const float sf2, const float sf3,
|
||||
const float pixel_offset) {
|
||||
using bicubic_interpolation::bicubic;
|
||||
|
||||
const int64_t index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
|
||||
|
||||
if (index >= dst_total_elements) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int i10_dst = index % ne10_dst;
|
||||
const int i11_dst = (index / ne10_dst) % ne11_dst;
|
||||
const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
|
||||
const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
|
||||
|
||||
const int i02_src = (int)(i12_dst / sf2);
|
||||
const int i03_src = (int)(i13_dst / sf3);
|
||||
|
||||
const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
|
||||
const int y0_src = (int)floorf(y_src_f);
|
||||
const float dy = y_src_f - (float)y0_src;
|
||||
|
||||
const float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;
|
||||
const int x0_src = (int)floorf(x_src_f);
|
||||
const float dx = x_src_f - (float)x0_src;
|
||||
|
||||
const char * x_base = (const char *)x + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03;
|
||||
|
||||
auto load = [=](int x_off, int y_off) -> float {
|
||||
int i00_src = max(0, min(x0_src + x_off, ne00_src - 1));
|
||||
int i01_src = max(0, min(y0_src + y_off, ne01_src - 1));
|
||||
return *(const float *)(x_base + (int64_t)i00_src * nb00 + (int64_t)i01_src * nb01);
|
||||
};
|
||||
|
||||
const float result = bicubic(
|
||||
bicubic(load(-1,-1), load(0,-1), load(1,-1), load(2,-1), dx),
|
||||
bicubic(load(-1, 0), load(0, 0), load(1, 0), load(2, 0), dx),
|
||||
bicubic(load(-1, 1), load(0, 1), load(1, 1), load(2, 1), dx),
|
||||
bicubic(load(-1, 2), load(0, 2), load(1, 2), load(2, 2), dx), dy);
|
||||
|
||||
dst[index] = result;
|
||||
}
|
||||
|
||||
static void upscale_f32_cuda(const float * x, float * dst,
|
||||
const int nb00, const int nb01, const int nb02, const int nb03,
|
||||
const int ne10, const int ne11, const int ne12, const int ne13,
|
||||
|
|
@ -104,6 +168,18 @@ static void upscale_f32_bilinear_cuda(const float * x, float * dst,
|
|||
upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
|
||||
}
|
||||
|
||||
static void upscale_f32_bicubic_cuda(const float * x, float * dst,
|
||||
const int nb00, const int nb01, const int nb02, const int nb03,
|
||||
const int ne00_src, const int ne01_src,
|
||||
const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
|
||||
const float sf0, const float sf1, const float sf2, const float sf3,
|
||||
const float pixel_offset, cudaStream_t stream) {
|
||||
const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
|
||||
const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
|
||||
|
||||
upscale_f32_bicubic<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
|
|
@ -121,17 +197,22 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||
float sf2 = (float)dst->ne[2]/src0->ne[2];
|
||||
const float sf3 = (float)dst->ne[3]/src0->ne[3];
|
||||
|
||||
float pixel_offset = 0.5f;
|
||||
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
|
||||
sf0 = dst->ne[0] > 1 && src0->ne[0] > 1 ? (float)(dst->ne[0] - 1) / (src0->ne[0] - 1) : sf0;
|
||||
sf1 = dst->ne[1] > 1 && src0->ne[1] > 1 ? (float)(dst->ne[1] - 1) / (src0->ne[1] - 1) : sf1;
|
||||
pixel_offset = 0.0f;
|
||||
}
|
||||
|
||||
if (mode == GGML_SCALE_MODE_NEAREST) {
|
||||
upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
|
||||
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
|
||||
float pixel_offset = 0.5f;
|
||||
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
|
||||
sf0 = dst->ne[0] > 1 && src0->ne[0] > 1 ? (float)(dst->ne[0] - 1) / (src0->ne[0] - 1) : sf0;
|
||||
sf1 = dst->ne[1] > 1 && src0->ne[1] > 1 ? (float)(dst->ne[1] - 1) / (src0->ne[1] - 1) : sf1;
|
||||
pixel_offset = 0.0f;
|
||||
}
|
||||
upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
||||
src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
sf0, sf1, sf2, sf3, pixel_offset, stream);
|
||||
} else if (mode == GGML_SCALE_MODE_BICUBIC) {
|
||||
upscale_f32_bicubic_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
||||
src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
sf0, sf1, sf2, sf3, pixel_offset, stream);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2944,8 +2944,11 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
|||
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
|
||||
case GGML_OP_PAD:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
||||
case GGML_OP_UPSCALE:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
||||
case GGML_OP_UPSCALE: {
|
||||
ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & 0xFF);
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 &&
|
||||
(mode == GGML_SCALE_MODE_NEAREST || mode == GGML_SCALE_MODE_BILINEAR);
|
||||
}
|
||||
case GGML_OP_CONV_2D:
|
||||
return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) ||
|
||||
(op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
|
||||
|
|
|
|||
|
|
@ -620,7 +620,7 @@ struct vk_device_struct {
|
|||
vk_pipeline pipeline_add_id_f32;
|
||||
|
||||
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
|
||||
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32;
|
||||
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32;
|
||||
vk_pipeline pipeline_scale_f32;
|
||||
vk_pipeline pipeline_sqr_f32;
|
||||
vk_pipeline pipeline_sqrt_f32;
|
||||
|
|
@ -3702,6 +3702,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_upscale_bicubic_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BICUBIC}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
|
|
@ -8200,6 +8201,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|||
return ctx->device->pipeline_upscale_nearest_f32;
|
||||
case GGML_SCALE_MODE_BILINEAR:
|
||||
return ctx->device->pipeline_upscale_bilinear_f32;
|
||||
case GGML_SCALE_MODE_BICUBIC:
|
||||
return ctx->device->pipeline_upscale_bicubic_f32;
|
||||
default:
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
|||
// from ggml.h: enum ggml_scale_mode, enum ggml_scale_flag
|
||||
#define NEAREST 0
|
||||
#define BILINEAR 1
|
||||
#define BICUBIC 2
|
||||
|
||||
layout (constant_id = 0) const uint scale_mode = 0;
|
||||
|
||||
|
|
@ -61,6 +62,39 @@ float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) {
|
|||
return fetch_bilinear(c0, c1, d, i12, i13);
|
||||
}
|
||||
|
||||
// Bicubic interpolation with alpha = -0.75
|
||||
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
|
||||
const vec4 bcoeffs1 = vec4( 1.25, -2.25, 0.0, 1.0);
|
||||
const vec4 bcoeffs2 = vec4(-0.75, 3.75, -6.0, 3.0);
|
||||
vec4 powers(float x) { return vec4(x*x*x, x*x, x, 1); }
|
||||
|
||||
float bicubic(float p0, float p1, float p2, float p3, float x) {
|
||||
return p0 * dot(bcoeffs2, powers(x + 1)) +
|
||||
p1 * dot(bcoeffs1, powers(x )) +
|
||||
p2 * dot(bcoeffs1, powers(1 - x)) +
|
||||
p3 * dot(bcoeffs2, powers(2 - x));
|
||||
}
|
||||
|
||||
#define FETCH(a,b) data_a[base + clamp(i.x+(a), 0, res.x) * p.nb00 + clamp(i.y+(b), 0, res.y) * p.nb01]
|
||||
|
||||
float interpolate_bicubic(uint i10, uint i11, uint i12, uint i13) {
|
||||
const ivec2 res = ivec2(p.ne00 - 1, p.ne01 - 1);
|
||||
|
||||
const vec2 coord = (vec2(i10, i11) + p.pixel_offset) / vec2(p.sf0, p.sf1) - p.pixel_offset;
|
||||
const vec2 d = fract(coord);
|
||||
const ivec2 i = ivec2(floor(coord));
|
||||
|
||||
const uint i02 = uint(i12 / p.sf2);
|
||||
const uint i03 = uint(i13 / p.sf3);
|
||||
const uint base = p.a_offset + i03 * p.nb03 + i02 * p.nb02;
|
||||
|
||||
return bicubic(
|
||||
bicubic(FETCH(-1,-1), FETCH(0,-1), FETCH(1,-1), FETCH(2,-1), d.x),
|
||||
bicubic(FETCH(-1, 0), FETCH(0, 0), FETCH(1, 0), FETCH(2, 0), d.x),
|
||||
bicubic(FETCH(-1, 1), FETCH(0, 1), FETCH(1, 1), FETCH(2, 1), d.x),
|
||||
bicubic(FETCH(-1, 2), FETCH(0, 2), FETCH(1, 2), FETCH(2, 2), d.x), d.y);
|
||||
}
|
||||
|
||||
void main() {
|
||||
const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||
|
||||
|
|
@ -81,6 +115,9 @@ void main() {
|
|||
case BILINEAR:
|
||||
result = interpolate_bilinear(i10, i11, i12, i13);
|
||||
break;
|
||||
case BICUBIC:
|
||||
result = interpolate_bicubic(i10, i11, i12, i13);
|
||||
break;
|
||||
}
|
||||
|
||||
data_d[p.d_offset + idx] = D_TYPE(result);
|
||||
|
|
|
|||
|
|
@ -272,6 +272,10 @@ static double mean_abs_asymm(const float * a, const float * b, const size_t n, c
|
|||
|
||||
// utils for printing the variables of the test cases
|
||||
|
||||
static std::string var_to_str(const std::string & x) {
|
||||
return x;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static std::string var_to_str(const T & x) {
|
||||
return std::to_string(x);
|
||||
|
|
@ -323,7 +327,8 @@ static std::string var_to_str(ggml_scale_mode mode) {
|
|||
switch (mode) {
|
||||
case GGML_SCALE_MODE_NEAREST: return "nearest";
|
||||
case GGML_SCALE_MODE_BILINEAR: return "bilinear";
|
||||
default: return std::to_string(mode);
|
||||
case GGML_SCALE_MODE_BICUBIC: return "bicubic";
|
||||
default: return std::to_string(mode);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -5218,7 +5223,9 @@ struct test_interpolate : public test_case {
|
|||
const uint32_t mode = GGML_SCALE_MODE_NEAREST;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR4(type, ne, ne_tgt, mode);
|
||||
ggml_scale_mode mode = (ggml_scale_mode)(this->mode & 0xFF);
|
||||
std::string flags = (this->mode & GGML_SCALE_FLAG_ALIGN_CORNERS) ? "align_corners" : "none";
|
||||
return VARS_TO_STR5(type, ne, ne_tgt, mode, flags);
|
||||
}
|
||||
|
||||
test_interpolate(ggml_type type = GGML_TYPE_F32,
|
||||
|
|
@ -7288,15 +7295,17 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
|
||||
}
|
||||
|
||||
for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) {
|
||||
for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC}) {
|
||||
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode));
|
||||
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true));
|
||||
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, mode));
|
||||
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {5, 7, 11, 13}, {2, 5, 7, 11}, mode));
|
||||
}
|
||||
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS));
|
||||
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {1, 4, 3, 2}, {2, 8, 3, 2}, GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS));
|
||||
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {4, 1, 3, 2}, {1, 1, 3, 2}, GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS));
|
||||
for (ggml_scale_mode mode : {GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC}) {
|
||||
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, mode | GGML_SCALE_FLAG_ALIGN_CORNERS));
|
||||
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {1, 4, 3, 2}, {2, 8, 3, 2}, mode | GGML_SCALE_FLAG_ALIGN_CORNERS));
|
||||
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {4, 1, 3, 2}, {1, 1, 3, 2}, mode | GGML_SCALE_FLAG_ALIGN_CORNERS));
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_sum());
|
||||
test_cases.emplace_back(new test_sum_rows());
|
||||
|
|
|
|||
Loading…
Reference in New Issue