pre-downsample position embeddings during GGUF conversion for fixed input size

Anav Prasad 2026-02-13 11:49:38 +00:00
parent a565bbd1b4
commit 1bb128d3e6
2 changed files with 20 additions and 4 deletions


@@ -4107,8 +4107,24 @@ class NemotronNanoV2VLModel(MmprojModel):
             return
         # RADIO's pos_embed doesn't have .weight suffix, but clip.cpp expects it
-        if "patch_generator.pos_embed" in name and not name.endswith(".weight"):
-            name += ".weight"
+        if "patch_generator.pos_embed" in name:
+            if not name.endswith(".weight"):
+                name += ".weight"
+            # Downsample position embeddings for fixed 512x512 image size
+            import torch.nn.functional as F
+            n_embd = self.hparams["hidden_size"]
+            image_size = self.global_config.get("force_image_size", 512)
+            patch_size = self.hparams["patch_size"]
+            target_patches_per_side = image_size // patch_size  # 32
+            max_patches_per_side = int((data_torch.shape[1]) ** 0.5)  # 128
+            if target_patches_per_side != max_patches_per_side:
+                # Reshape to grid, interpolate, flatten back
+                data_torch = data_torch.reshape(1, max_patches_per_side, max_patches_per_side, n_embd)
+                data_torch = data_torch.permute(0, 3, 1, 2).float()  # [1, n_embd, 128, 128]
+                data_torch = F.interpolate(data_torch, size=(target_patches_per_side, target_patches_per_side),
+                                           mode='bilinear', align_corners=True)
+                data_torch = data_torch.permute(0, 2, 3, 1)  # [1, 32, 32, n_embd]
+                data_torch = data_torch.reshape(1, target_patches_per_side * target_patches_per_side, n_embd)
 
         # Reshape linear patch embedding to conv2d format for ggml_conv_2d
         # From [n_embd, patch_size*patch_size*3] to [n_embd, 3, patch_size, patch_size]


@@ -9,8 +9,8 @@ ggml_cgraph * clip_graph_nemotron_v2_vl::build() {
     ggml_tensor * inp = build_inp();
 
-    ggml_tensor * pos_embd = resize_position_embeddings();
-    inp = ggml_add(ctx0, inp, pos_embd);
+    // add position embeddings (pre-downsampled during GGUF conversion for fixed 512x512 input)
+    inp = ggml_add(ctx0, inp, model.position_embeddings);
     cb(inp, "inp_pos", -1);
 
     inp = ggml_concat(ctx0, model.class_embedding, inp, 1);