model: LFM2-VL fixes (#17577)

* Adjust to pytorch

* Add antialiasing upscale

* Increase number of patches to 1024

* Handle default marker insertion for LFM2

* Switch to flag

* Reformat

* Cuda implementation of antialias kernel

* Change placement in ops.cpp

* consistent float literals

* Pad only for LFM2

* Address PR feedback

* Rollback default marker placement changes

* Fallback to CPU implementation for antialias implementation of upscale
This commit is contained in:
Tarek Dakhran 2025-11-30 21:57:31 +01:00 committed by GitHub
parent 7f8ef50cce
commit 2ba719519d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 162 additions and 13 deletions

View File

@ -2148,7 +2148,8 @@ extern "C" {
};
enum ggml_scale_flag {
GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8)
GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8),
GGML_SCALE_FLAG_ANTIALIAS = (1 << 9),
};
// interpolate

View File

@ -2500,6 +2500,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
if (op->op_params[0] != GGML_SCALE_MODE_NEAREST) {
return false;
}
if (op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS) {
return false;
}
return true;
}
case GGML_OP_POOL_2D:

View File

@ -7420,6 +7420,65 @@ static void ggml_compute_forward_upscale_f32(
}
}
}
} else if (mode == GGML_SCALE_MODE_BILINEAR && (mode_flags & GGML_SCALE_FLAG_ANTIALIAS)) {
// Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
auto triangle_filter = [](float x) -> float {
return std::max(1.0f - fabsf(x), 0.0f);
};
// support and invscale, minimum 1 pixel for bilinear
const float support1 = std::max(1.0f, 1.0f / sf1);
const float invscale1 = 1.0f / support1;
const float support0 = std::max(1.0f, 1.0f / sf0);
const float invscale0 = 1.0f / support0;
for (int64_t i3 = 0; i3 < ne3; i3++) {
const int64_t i03 = i3 / sf3;
for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
const int64_t i02 = i2 / sf2;
for (int64_t i1 = 0; i1 < ne1; i1++) {
const float y = ((float) i1 + pixel_offset) / sf1;
for (int64_t i0 = 0; i0 < ne0; i0++) {
const float x = ((float) i0 + pixel_offset) / sf0;
// the range of source pixels that contribute
const int64_t x_min = std::max<int64_t>(x - support0 + pixel_offset, 0);
const int64_t x_max = std::min<int64_t>(x + support0 + pixel_offset, ne00);
const int64_t y_min = std::max<int64_t>(y - support1 + pixel_offset, 0);
const int64_t y_max = std::min<int64_t>(y + support1 + pixel_offset, ne01);
// bilinear filter with antialiasing
float val = 0.0f;
float total_weight = 0.0f;
for (int64_t sy = y_min; sy < y_max; sy++) {
const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1);
for (int64_t sx = x_min; sx < x_max; sx++) {
const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0);
const float weight = weight_x * weight_y;
if (weight <= 0.0f) {
continue;
}
const float pixel = *(const float *)((const char *)src0->data + sx*nb00 + sy*nb01 + i02*nb02 + i03*nb03);
val += pixel * weight;
total_weight += weight;
}
}
if (total_weight > 0.0f) {
val /= total_weight;
}
float * dst_ptr = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
*dst_ptr = val;
}
}
}
}
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
for (int64_t i3 = 0; i3 < ne3; i3++) {
const int64_t i03 = i3 / sf3;

View File

@ -81,6 +81,76 @@ static __global__ void upscale_f32_bilinear(const float * x, float * dst,
dst[index] = result;
}
// Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
static __global__ void upscale_f32_bilinear_antialias(const float * src0, float * dst,
const int nb00, const int nb01, const int nb02, const int nb03,
const int ne00_src, const int ne01_src,
const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
const float sf0, const float sf1, const float sf2, const float sf3,
const float pixel_offset) {
const int64_t index = threadIdx.x + blockIdx.x * blockDim.x;
const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
if (index >= dst_total_elements) {
return;
}
const int i10_dst = index % ne10_dst;
const int i11_dst = (index / ne10_dst) % ne11_dst;
const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
const int i02_src = (int)(i12_dst / sf2);
const int i03_src = (int)(i13_dst / sf3);
const float y = ((float)i11_dst + pixel_offset) / sf1;
const float x = ((float)i10_dst + pixel_offset) / sf0;
// support and invscale, minimum 1 pixel for bilinear
const float support1 = max(1.0f / sf1, 1.0f);
const float invscale1 = 1.0f / support1;
const float support0 = max(1.0f / sf0, 1.0f);
const float invscale0 = 1.0f / support0;
// the range of source pixels that contribute
const int64_t x_min = max(int64_t(0), int64_t(x - support0 + pixel_offset));
const int64_t x_max = min(int64_t(ne00_src), int64_t(x + support0 + pixel_offset));
const int64_t y_min = max(int64_t(0), int64_t(y - support1 + pixel_offset));
const int64_t y_max = min(int64_t(ne01_src), int64_t(y + support1 + pixel_offset));
// bilinear filter with antialiasing
float val = 0.0f;
float total_weight = 0.0f;
auto triangle_filter = [](float x) -> float {
return max(1.0f - fabsf(x), 0.0f);
};
for (int64_t sy = y_min; sy < y_max; sy++) {
const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1);
for (int64_t sx = x_min; sx < x_max; sx++) {
const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0);
const float weight = weight_x * weight_y;
if (weight <= 0.0f) {
continue;
}
const float pixel = *(const float *)((const char *)src0 + sx*nb00 + sy*nb01 + i02_src*nb02 + i03_src*nb03);
val += pixel * weight;
total_weight += weight;
}
}
if (total_weight > 0.0f) {
val /= total_weight;
}
dst[index] = val;
}
namespace bicubic_interpolation {
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
__device__ const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
@ -161,11 +231,15 @@ static void upscale_f32_bilinear_cuda(const float * x, float * dst,
const int ne00_src, const int ne01_src,
const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
const float sf0, const float sf1, const float sf2, const float sf3,
const float pixel_offset, cudaStream_t stream) {
const float pixel_offset, bool antialias, cudaStream_t stream) {
const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
if (antialias) {
upscale_f32_bilinear_antialias<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
} else {
upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
}
}
static void upscale_f32_bicubic_cuda(const float * x, float * dst,
@ -207,9 +281,10 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
if (mode == GGML_SCALE_MODE_NEAREST) {
upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS);
upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
sf0, sf1, sf2, sf3, pixel_offset, stream);
sf0, sf1, sf2, sf3, pixel_offset, antialias, stream);
} else if (mode == GGML_SCALE_MODE_BICUBIC) {
upscale_f32_bicubic_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],

View File

@ -894,7 +894,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_POOL_1D:
return false;
case GGML_OP_UPSCALE:
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
case GGML_OP_POOL_2D:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_PAD:

View File

@ -3086,8 +3086,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_UPSCALE: {
ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & 0xFF);
const bool antialias = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & GGML_SCALE_FLAG_ANTIALIAS);
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 &&
(mode == GGML_SCALE_MODE_NEAREST || mode == GGML_SCALE_MODE_BILINEAR);
(mode == GGML_SCALE_MODE_NEAREST || mode == GGML_SCALE_MODE_BILINEAR) && !antialias;
}
case GGML_OP_CONV_2D:
return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) ||

View File

@ -4597,7 +4597,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_IM2COL:
return true;
case GGML_OP_UPSCALE:
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:

View File

@ -14113,6 +14113,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
}
return true;
case GGML_OP_UPSCALE:
return op->src[0]->type == GGML_TYPE_F32 && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
case GGML_OP_ACC:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_CONCAT:

View File

@ -4891,6 +4891,8 @@ static struct ggml_tensor * ggml_interpolate_impl(
int64_t ne3,
uint32_t mode) {
GGML_ASSERT((mode & 0xFF) < GGML_SCALE_MODE_COUNT);
// TODO: implement antialias for modes other than bilinear
GGML_ASSERT(!(mode & GGML_SCALE_FLAG_ANTIALIAS) || (mode & 0xFF) == GGML_SCALE_MODE_BILINEAR);
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);

View File

@ -7660,7 +7660,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
// test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {i, 2, 1, 3}, rand() % i + 1));
//}
for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC}) {
for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC, ggml_scale_mode(GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS)}) {
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode));
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true));
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, mode));

View File

@ -2020,7 +2020,7 @@ private:
ggml_tensor * pos_embd = model.position_embeddings;
const int height = img.ny / patch_size;
const int width = img.nx / patch_size;
const uint32_t mode = GGML_SCALE_MODE_BILINEAR;
const uint32_t mode = GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS;
const int n_per_side = (int)std::sqrt(pos_embd->ne[1]);
GGML_ASSERT(pos_embd);
@ -2795,7 +2795,8 @@ struct clip_model_loader {
{
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
// ref: https://huggingface.co/LiquidAI/LFM2-VL-3B/blob/main/preprocessor_config.json
hparams.set_limit_image_tokens(64, 256);
// config above specifies number of tokens after downsampling, while here it is before, relax lowerbound to 64
hparams.set_limit_image_tokens(64, 1024);
} break;
case PROJECTOR_TYPE_PIXTRAL:
case PROJECTOR_TYPE_LIGHTONOCR:
@ -3745,12 +3746,13 @@ struct img_tool {
const int width = inp_size.width;
const int height = inp_size.height;
auto round_by_factor = [f = align_size](float x) { return static_cast<int>(std::round(x / static_cast<float>(f))) * f; };
auto ceil_by_factor = [f = align_size](float x) { return static_cast<int>(std::ceil(x / static_cast<float>(f))) * f; };
auto floor_by_factor = [f = align_size](float x) { return static_cast<int>(std::floor(x / static_cast<float>(f))) * f; };
// always align up first
int h_bar = std::max(align_size, ceil_by_factor(height));
int w_bar = std::max(align_size, ceil_by_factor(width));
int h_bar = std::max(align_size, round_by_factor(height));
int w_bar = std::max(align_size, round_by_factor(width));
if (h_bar * w_bar > max_pixels) {
const auto beta = std::sqrt(static_cast<float>(height * width) / max_pixels);
@ -4365,7 +4367,8 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
const std::array<uint8_t, 3> pad_color = {122, 116, 104};
clip_image_u8 resized_img;
img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BILINEAR, true, pad_color);
const bool pad = (ctx->proj_type() != PROJECTOR_TYPE_LFM2);
img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BILINEAR, pad, pad_color);
clip_image_f32_ptr res(clip_image_f32_init());
normalize_image_u8_to_f32(resized_img, *res, params.image_mean, params.image_std);
res_imgs->entries.push_back(std::move(res));

View File

@ -304,6 +304,10 @@ struct mtmd_context {
img_beg = "<|im_start|>";
img_end = "<|im_end|>";
} else if (proj == PROJECTOR_TYPE_LFM2) {
img_beg = "<|image_start|>";
img_end = "<|image_end|>";
}
}