cuda/vulkan : bicubic interpolation (#17022)

* vulkan : implement upscale with bicubic interpolation

* cuda : implement upscale with bicubic interpolation

* tests : add ggml_interpolate with GGML_SCALE_MODE_BICUBIC to backend tests

* adapt OpenCL backend to not support the OP in that case so tests don't fail

* print scale mode & flags in test-backend-ops
This commit is contained in:
Acly 2025-11-10 10:19:39 +01:00 committed by GitHub
parent 15274c0c50
commit 1032256ec9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 148 additions and 15 deletions

View File

@ -81,6 +81,70 @@ static __global__ void upscale_f32_bilinear(const float * x, float * dst,
dst[index] = result; dst[index] = result;
} }
namespace bicubic_interpolation {
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
__device__ const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
static __device__ float weight1(float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };
static __device__ float weight2(float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
static __device__ float bicubic(float p0, float p1, float p2, float p3, float x) {
const float w0 = weight2(x + 1);
const float w1 = weight1(x + 0);
const float w2 = weight1(1 - x);
const float w3 = weight2(2 - x);
return p0 * w0 + p1 * w1 + p2 * w2 + p3 * w3;
};
} // namespace bicubic_interpolation
static __global__ void upscale_f32_bicubic(const float * x, float * dst,
const int nb00, const int nb01, const int nb02, const int nb03,
const int ne00_src, const int ne01_src,
const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
const float sf0, const float sf1, const float sf2, const float sf3,
const float pixel_offset) {
using bicubic_interpolation::bicubic;
const int64_t index = threadIdx.x + blockIdx.x * blockDim.x;
const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
if (index >= dst_total_elements) {
return;
}
const int i10_dst = index % ne10_dst;
const int i11_dst = (index / ne10_dst) % ne11_dst;
const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
const int i02_src = (int)(i12_dst / sf2);
const int i03_src = (int)(i13_dst / sf3);
const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
const int y0_src = (int)floorf(y_src_f);
const float dy = y_src_f - (float)y0_src;
const float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;
const int x0_src = (int)floorf(x_src_f);
const float dx = x_src_f - (float)x0_src;
const char * x_base = (const char *)x + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03;
auto load = [=](int x_off, int y_off) -> float {
int i00_src = max(0, min(x0_src + x_off, ne00_src - 1));
int i01_src = max(0, min(y0_src + y_off, ne01_src - 1));
return *(const float *)(x_base + (int64_t)i00_src * nb00 + (int64_t)i01_src * nb01);
};
const float result = bicubic(
bicubic(load(-1,-1), load(0,-1), load(1,-1), load(2,-1), dx),
bicubic(load(-1, 0), load(0, 0), load(1, 0), load(2, 0), dx),
bicubic(load(-1, 1), load(0, 1), load(1, 1), load(2, 1), dx),
bicubic(load(-1, 2), load(0, 2), load(1, 2), load(2, 2), dx), dy);
dst[index] = result;
}
static void upscale_f32_cuda(const float * x, float * dst, static void upscale_f32_cuda(const float * x, float * dst,
const int nb00, const int nb01, const int nb02, const int nb03, const int nb00, const int nb01, const int nb02, const int nb03,
const int ne10, const int ne11, const int ne12, const int ne13, const int ne10, const int ne11, const int ne12, const int ne13,
@ -104,6 +168,18 @@ static void upscale_f32_bilinear_cuda(const float * x, float * dst,
upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
} }
static void upscale_f32_bicubic_cuda(const float * x, float * dst,
const int nb00, const int nb01, const int nb02, const int nb03,
const int ne00_src, const int ne01_src,
const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
const float sf0, const float sf1, const float sf2, const float sf3,
const float pixel_offset, cudaStream_t stream) {
const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
upscale_f32_bicubic<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
}
void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data; const float * src0_d = (const float *)src0->data;
@ -121,17 +197,22 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float sf2 = (float)dst->ne[2]/src0->ne[2]; float sf2 = (float)dst->ne[2]/src0->ne[2];
const float sf3 = (float)dst->ne[3]/src0->ne[3]; const float sf3 = (float)dst->ne[3]/src0->ne[3];
float pixel_offset = 0.5f;
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
sf0 = dst->ne[0] > 1 && src0->ne[0] > 1 ? (float)(dst->ne[0] - 1) / (src0->ne[0] - 1) : sf0;
sf1 = dst->ne[1] > 1 && src0->ne[1] > 1 ? (float)(dst->ne[1] - 1) / (src0->ne[1] - 1) : sf1;
pixel_offset = 0.0f;
}
if (mode == GGML_SCALE_MODE_NEAREST) { if (mode == GGML_SCALE_MODE_NEAREST) {
upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream); upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
} else if (mode == GGML_SCALE_MODE_BILINEAR) { } else if (mode == GGML_SCALE_MODE_BILINEAR) {
float pixel_offset = 0.5f;
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
sf0 = dst->ne[0] > 1 && src0->ne[0] > 1 ? (float)(dst->ne[0] - 1) / (src0->ne[0] - 1) : sf0;
sf1 = dst->ne[1] > 1 && src0->ne[1] > 1 ? (float)(dst->ne[1] - 1) / (src0->ne[1] - 1) : sf1;
pixel_offset = 0.0f;
}
upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
sf0, sf1, sf2, sf3, pixel_offset, stream); sf0, sf1, sf2, sf3, pixel_offset, stream);
} else if (mode == GGML_SCALE_MODE_BICUBIC) {
upscale_f32_bicubic_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
sf0, sf1, sf2, sf3, pixel_offset, stream);
} }
} }

View File

@ -2944,8 +2944,11 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
case GGML_OP_PAD: case GGML_OP_PAD:
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_UPSCALE: case GGML_OP_UPSCALE: {
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & 0xFF);
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 &&
(mode == GGML_SCALE_MODE_NEAREST || mode == GGML_SCALE_MODE_BILINEAR);
}
case GGML_OP_CONV_2D: case GGML_OP_CONV_2D:
return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) || return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) ||
(op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) || (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||

View File

@ -620,7 +620,7 @@ struct vk_device_struct {
vk_pipeline pipeline_add_id_f32; vk_pipeline pipeline_add_id_f32;
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32; vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32; vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32;
vk_pipeline pipeline_scale_f32; vk_pipeline pipeline_scale_f32;
vk_pipeline pipeline_sqr_f32; vk_pipeline pipeline_sqr_f32;
vk_pipeline pipeline_sqrt_f32; vk_pipeline pipeline_sqrt_f32;
@ -3702,6 +3702,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1); ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1); ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
ggml_vk_create_pipeline(device, device->pipeline_upscale_bicubic_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BICUBIC}, 1);
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@ -8200,6 +8201,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_upscale_nearest_f32; return ctx->device->pipeline_upscale_nearest_f32;
case GGML_SCALE_MODE_BILINEAR: case GGML_SCALE_MODE_BILINEAR:
return ctx->device->pipeline_upscale_bilinear_f32; return ctx->device->pipeline_upscale_bilinear_f32;
case GGML_SCALE_MODE_BICUBIC:
return ctx->device->pipeline_upscale_bicubic_f32;
default: default:
return nullptr; return nullptr;
} }

View File

@ -20,6 +20,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
// from ggml.h: enum ggml_scale_mode, enum ggml_scale_flag // from ggml.h: enum ggml_scale_mode, enum ggml_scale_flag
#define NEAREST 0 #define NEAREST 0
#define BILINEAR 1 #define BILINEAR 1
#define BICUBIC 2
layout (constant_id = 0) const uint scale_mode = 0; layout (constant_id = 0) const uint scale_mode = 0;
@ -61,6 +62,39 @@ float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) {
return fetch_bilinear(c0, c1, d, i12, i13); return fetch_bilinear(c0, c1, d, i12, i13);
} }
// Bicubic interpolation with alpha = -0.75
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
const vec4 bcoeffs1 = vec4( 1.25, -2.25, 0.0, 1.0);
const vec4 bcoeffs2 = vec4(-0.75, 3.75, -6.0, 3.0);
vec4 powers(float x) { return vec4(x*x*x, x*x, x, 1); }
float bicubic(float p0, float p1, float p2, float p3, float x) {
return p0 * dot(bcoeffs2, powers(x + 1)) +
p1 * dot(bcoeffs1, powers(x )) +
p2 * dot(bcoeffs1, powers(1 - x)) +
p3 * dot(bcoeffs2, powers(2 - x));
}
#define FETCH(a,b) data_a[base + clamp(i.x+(a), 0, res.x) * p.nb00 + clamp(i.y+(b), 0, res.y) * p.nb01]
float interpolate_bicubic(uint i10, uint i11, uint i12, uint i13) {
const ivec2 res = ivec2(p.ne00 - 1, p.ne01 - 1);
const vec2 coord = (vec2(i10, i11) + p.pixel_offset) / vec2(p.sf0, p.sf1) - p.pixel_offset;
const vec2 d = fract(coord);
const ivec2 i = ivec2(floor(coord));
const uint i02 = uint(i12 / p.sf2);
const uint i03 = uint(i13 / p.sf3);
const uint base = p.a_offset + i03 * p.nb03 + i02 * p.nb02;
return bicubic(
bicubic(FETCH(-1,-1), FETCH(0,-1), FETCH(1,-1), FETCH(2,-1), d.x),
bicubic(FETCH(-1, 0), FETCH(0, 0), FETCH(1, 0), FETCH(2, 0), d.x),
bicubic(FETCH(-1, 1), FETCH(0, 1), FETCH(1, 1), FETCH(2, 1), d.x),
bicubic(FETCH(-1, 2), FETCH(0, 2), FETCH(1, 2), FETCH(2, 2), d.x), d.y);
}
void main() { void main() {
const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
@ -81,6 +115,9 @@ void main() {
case BILINEAR: case BILINEAR:
result = interpolate_bilinear(i10, i11, i12, i13); result = interpolate_bilinear(i10, i11, i12, i13);
break; break;
case BICUBIC:
result = interpolate_bicubic(i10, i11, i12, i13);
break;
} }
data_d[p.d_offset + idx] = D_TYPE(result); data_d[p.d_offset + idx] = D_TYPE(result);

View File

@ -272,6 +272,10 @@ static double mean_abs_asymm(const float * a, const float * b, const size_t n, c
// utils for printing the variables of the test cases // utils for printing the variables of the test cases
static std::string var_to_str(const std::string & x) {
return x;
}
template<typename T> template<typename T>
static std::string var_to_str(const T & x) { static std::string var_to_str(const T & x) {
return std::to_string(x); return std::to_string(x);
@ -323,7 +327,8 @@ static std::string var_to_str(ggml_scale_mode mode) {
switch (mode) { switch (mode) {
case GGML_SCALE_MODE_NEAREST: return "nearest"; case GGML_SCALE_MODE_NEAREST: return "nearest";
case GGML_SCALE_MODE_BILINEAR: return "bilinear"; case GGML_SCALE_MODE_BILINEAR: return "bilinear";
default: return std::to_string(mode); case GGML_SCALE_MODE_BICUBIC: return "bicubic";
default: return std::to_string(mode);
} }
} }
@ -5218,7 +5223,9 @@ struct test_interpolate : public test_case {
const uint32_t mode = GGML_SCALE_MODE_NEAREST; const uint32_t mode = GGML_SCALE_MODE_NEAREST;
std::string vars() override { std::string vars() override {
return VARS_TO_STR4(type, ne, ne_tgt, mode); ggml_scale_mode mode = (ggml_scale_mode)(this->mode & 0xFF);
std::string flags = (this->mode & GGML_SCALE_FLAG_ALIGN_CORNERS) ? "align_corners" : "none";
return VARS_TO_STR5(type, ne, ne_tgt, mode, flags);
} }
test_interpolate(ggml_type type = GGML_TYPE_F32, test_interpolate(ggml_type type = GGML_TYPE_F32,
@ -7288,15 +7295,17 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection) test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
} }
for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) { for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC}) {
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode)); test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode));
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true)); test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true));
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, mode)); test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, mode));
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {5, 7, 11, 13}, {2, 5, 7, 11}, mode)); test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {5, 7, 11, 13}, {2, 5, 7, 11}, mode));
} }
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS)); for (ggml_scale_mode mode : {GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC}) {
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {1, 4, 3, 2}, {2, 8, 3, 2}, GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS)); test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, mode | GGML_SCALE_FLAG_ALIGN_CORNERS));
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {4, 1, 3, 2}, {1, 1, 3, 2}, GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS)); test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {1, 4, 3, 2}, {2, 8, 3, 2}, mode | GGML_SCALE_FLAG_ALIGN_CORNERS));
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {4, 1, 3, 2}, {1, 1, 3, 2}, mode | GGML_SCALE_FLAG_ALIGN_CORNERS));
}
test_cases.emplace_back(new test_sum()); test_cases.emplace_back(new test_sum());
test_cases.emplace_back(new test_sum_rows()); test_cases.emplace_back(new test_sum_rows());