Merge pull request #1 from bluebread/sf/deepseek-ocr

mtmd: fix vision model processing
Saba Fallah, 2025-11-15 11:51:21 +01:00, committed by GitHub
commit 578c8d77dc
4 changed files with 112 additions and 48 deletions


@@ -1445,7 +1445,7 @@ class MmprojModel(ModelBase):
     preprocessor_config: dict[str, Any]
     global_config: dict[str, Any]
 
-    n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "width.clip-l-14-224.layers", "sam_vit_b.layers"]
+    n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "layers"]
 
     has_vision_encoder: bool = True  # by default
     has_audio_encoder: bool = False
@@ -1494,8 +1494,8 @@ class MmprojModel(ModelBase):
         # FIXME: DeepseekOCRVisionModel specific hack
         if self.block_count is None:
             if isinstance(self, DeepseekOCRVisionModel):
-                clip_block_count = self.hparams['width']['clip-l-14-224']['layers']
-                sam_block_count = self.hparams['width']['sam_vit_b']['layers']
+                print(self.hparams)
+                clip_block_count = self.hparams['layers']
                 if clip_block_count is not None:
                     self.block_count = clip_block_count
-                if sam_block_count is not None:
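Note: the simplified `self.hparams['layers']` lookup only works because the reworked `get_vision_config()` further down flattens the CLIP sub-config into the top-level hparams. A minimal sketch of that key-probing idea, not part of the commit; `find_first_key` is a hypothetical helper and the values are illustrative:

from typing import Any

def find_first_key(hparams: dict[str, Any], keys: list[str]) -> Any | None:
    # Probe the candidate hyperparameter names in order; return the first hit.
    for key in keys:
        if key in hparams:
            return hparams[key]
    return None

# Once get_vision_config() has flattened the CLIP sub-config, a plain
# "layers" key is present at the top level.
flattened = {"layers": 24, "width": 1024, "heads": 16}  # illustrative values
n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "layers"]
print(find_first_key(flattened, n_block_keys))  # -> 24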
@@ -5793,6 +5793,16 @@ class Gemma3VisionModel(MmprojModel):
 @ModelBase.register("DeepseekOCRForCausalLM")
 class DeepseekOCRVisionModel(MmprojModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        proc_fname = self.dir_model / "processor_config.json"
+        if proc_fname.is_file():
+            with open(proc_fname, "r") as f:
+                self.preprocessor_config = json.load(f)
+
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         hparams = self.hparams
@@ -5811,10 +5821,25 @@ class DeepseekOCRVisionModel(MmprojModel):
             # in this case, we are converting a test model
             self.gguf_writer.add_vision_projector_scale_factor(proj_scale_factor)
 
+        # SAM configuration
+        sam_hparams = hparams['sam']
+        self.gguf_writer.add_vision_sam_layers_count(sam_hparams['layers'])
+        self.gguf_writer.add_vision_sam_embedding_length(sam_hparams['width'])
+
     def get_vision_config(self) -> dict[str, Any]:
-        orig_vision_config = self.global_config.get("vision_config")
-        super().get_vision_config()
+        vision_config: dict[str, Any] | None = self.global_config.get("vision_config")
+
+        if not vision_config:
+            raise ValueError("DeepseekOCR model requires 'vision_config' in the model configuration, but it was not found")
+
+        vision_config['sam'] = vision_config['width']['sam_vit_b']
+        vision_config.update(vision_config['width']['clip-l-14-224'])
+        vision_config['hidden_size'] = vision_config['width']
+        vision_config['num_heads'] = vision_config['heads']
+        vision_config['intermediate_size'] = vision_config['heads'] * 4
+
+        return vision_config
 
     def tensor_force_quant(self, name, new_name, bid, n_dims):
         # related to https://github.com/ggml-org/llama.cpp/issues/13025
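For reference, a standalone sketch of what this flattening does to the nested DeepSeek-OCR `vision_config`; the sample values are illustrative, not taken from the actual checkpoint, and the helper name is hypothetical:

from typing import Any

def flatten_deepseek_ocr_vision_config(vision_config: dict[str, Any]) -> dict[str, Any]:
    # Mirror of the conversion above: keep the SAM sub-config under 'sam'
    # and hoist the CLIP sub-config to the top level.
    vision_config['sam'] = vision_config['width']['sam_vit_b']
    vision_config.update(vision_config['width']['clip-l-14-224'])
    vision_config['hidden_size'] = vision_config['width']
    vision_config['num_heads'] = vision_config['heads']
    vision_config['intermediate_size'] = vision_config['heads'] * 4  # follows the diff verbatim
    return vision_config

# Illustrative shape only; real checkpoints may carry more fields.
cfg = {
    "width": {
        "clip-l-14-224": {"layers": 24, "width": 1024, "heads": 16},
        "sam_vit_b": {"layers": 12, "width": 768},
    }
}
flat = flatten_deepseek_ocr_vision_config(cfg)
print(flat["layers"], flat["hidden_size"], flat["sam"]["width"])  # -> 24 1024 768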
@@ -5825,26 +5850,16 @@ class DeepseekOCRVisionModel(MmprojModel):
         return super().tensor_force_quant(name, new_name, bid, n_dims)
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        del bid  # unused
-
-        if "vision_model.head." in name:
-            return []  # skip redundant tensors for tinygemma3
-
-        if name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \
-                or name.startswith("multimodal_projector.") or name.startswith("vision_model."):
-            # process vision tensors
-            name = name.replace("_weight", ".weight")
-
-            # correct norm value ; only this "soft_emb_norm" need to be corrected as it's part of Gemma projector
-            # the other norm values are part of SigLIP model, and they are already correct
-            # ref code: Gemma3RMSNorm
-            if "soft_emb_norm.weight" in name:
-                logger.info(f"Correcting norm value for '{name}'")
-                data_torch = data_torch + 1
-
-            return [(self.map_tensor_name(name), data_torch)]
-
-        return []  # skip other tensors
+        # Only process vision-related tensors, skip language model tensors
+        # Vision components: sam_model, vision_model, projector, image_newline, view_seperator
+        # Language model components to skip: lm_head, embed_tokens, layers, norm
+        if name.startswith(("lm_head.", "model.embed_tokens.", "model.layers.", "model.norm.")):
+            return []
+
+        if ".attn.rel_pos_h" in name or ".attn.rel_pos_w" in name:
+            return [(self.map_tensor_name(name, try_suffixes=("",)), data_torch)]
+
+        return [(self.map_tensor_name(name), data_torch)]
 
 
 @ModelBase.register("Gemma3nForConditionalGeneration")


@@ -287,6 +287,10 @@ class Keys:
         class Projector:
             SCALE_FACTOR = "clip.vision.projector.scale_factor"
 
+        class SAM:
+            BLOCK_COUNT = "clip.vision.sam.block_count"
+            EMBEDDING_LENGTH = "clip.vision.sam.embedding_length"
+
     class ClipAudio:
         NUM_MEL_BINS = "clip.audio.num_mel_bins"
         EMBEDDING_LENGTH = "clip.audio.embedding_length"
@@ -664,20 +668,21 @@ class MODEL_TENSOR(IntEnum):
     V_MM_GATE = auto()  # cogvlm
     V_TOK_BOI = auto()  # cogvlm
     V_TOK_EOI = auto()  # cogvlm
-    # DeepSeek-OCR sam_model
-    V_SAM_POS_EMBD = auto()
-    V_SAM_PATCH_EMBD = auto()
-    V_SAM_PRE_NORM = auto()
-    V_SAM_POST_NORM = auto()
-    V_SAM_ATTN_POS_H = auto()
-    V_SAM_ATTN_POS_W = auto()
-    V_SAM_ATTN_QKV = auto()
-    V_SAM_ATTN_OUT = auto()
-    V_SAM_MLP_LIN_1 = auto()
-    V_SAM_MLP_LIN_2 = auto()
-    V_SAM_NECK = auto()
-    V_SAM_NET_2 = auto()
-    V_SAM_NET_3 = auto()
+    V_SAM_POS_EMBD = auto()  # Deepseek-OCR
+    V_SAM_PATCH_EMBD = auto()  # Deepseek-OCR
+    V_SAM_PRE_NORM = auto()  # Deepseek-OCR
+    V_SAM_POST_NORM = auto()  # Deepseek-OCR
+    V_SAM_ATTN_POS_H = auto()  # Deepseek-OCR
+    V_SAM_ATTN_POS_W = auto()  # Deepseek-OCR
+    V_SAM_ATTN_QKV = auto()  # Deepseek-OCR
+    V_SAM_ATTN_OUT = auto()  # Deepseek-OCR
+    V_SAM_MLP_LIN_1 = auto()  # Deepseek-OCR
+    V_SAM_MLP_LIN_2 = auto()  # Deepseek-OCR
+    V_SAM_NECK = auto()  # Deepseek-OCR
+    V_SAM_NET_2 = auto()  # Deepseek-OCR
+    V_SAM_NET_3 = auto()  # Deepseek-OCR
+    V_ENC_EMBD_IMGNL = auto()  # Deepseek-OCR
+    V_ENC_EMBD_VSEP = auto()  # Deepseek-OCR
 
     # audio (mtmd)
     A_ENC_EMBD_POS = auto()
@@ -1059,6 +1064,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.V_SAM_NECK: "v.sam.neck.{bid}",
     MODEL_TENSOR.V_SAM_NET_2: "v.sam.net_2",
     MODEL_TENSOR.V_SAM_NET_3: "v.sam.net_3",
+    MODEL_TENSOR.V_ENC_EMBD_IMGNL: "v.image_newline_embd",  # Deepseek-OCR
+    MODEL_TENSOR.V_ENC_EMBD_VSEP: "v.view_separator_embd",  # Deepseek-OCR
     # audio (mtmd)
     MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd",
     MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}",
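These entries map enum members to the on-disk GGUF tensor names; per-block entries carry a `{bid}` placeholder that is filled with the block index. A tiny sketch using only names from this hunk (the dict here is an illustrative subset, not the real table):

# Illustrative subset of TENSOR_NAMES from the hunk above.
TENSOR_NAMES_SUBSET = {
    "V_SAM_NECK": "v.sam.neck.{bid}",
    "V_ENC_EMBD_IMGNL": "v.image_newline_embd",
    "V_ENC_EMBD_VSEP": "v.view_separator_embd",
}

print(TENSOR_NAMES_SUBSET["V_SAM_NECK"].format(bid=2))  # -> v.sam.neck.2
print(TENSOR_NAMES_SUBSET["V_ENC_EMBD_IMGNL"])          # -> v.image_newline_embd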
@@ -1095,6 +1102,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.V_ENC_EMBD_CLS,
         MODEL_TENSOR.V_ENC_EMBD_PATCH,
         MODEL_TENSOR.V_ENC_EMBD_POS,
+        MODEL_TENSOR.V_ENC_EMBD_IMGNL,
+        MODEL_TENSOR.V_ENC_EMBD_VSEP,
         MODEL_TENSOR.V_ENC_INPUT_NORM,
         MODEL_TENSOR.V_ENC_ATTN_QKV,
         MODEL_TENSOR.V_ENC_ATTN_Q,
@@ -1137,6 +1146,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.V_MM_GATE,
         MODEL_TENSOR.V_TOK_BOI,
         MODEL_TENSOR.V_TOK_EOI,
+        MODEL_TENSOR.V_SAM_POS_EMBD,
+        MODEL_TENSOR.V_SAM_PATCH_EMBD,
+        MODEL_TENSOR.V_SAM_PRE_NORM,
+        MODEL_TENSOR.V_SAM_POST_NORM,
+        MODEL_TENSOR.V_SAM_ATTN_POS_H,
+        MODEL_TENSOR.V_SAM_ATTN_POS_W,
+        MODEL_TENSOR.V_SAM_ATTN_QKV,
+        MODEL_TENSOR.V_SAM_ATTN_OUT,
+        MODEL_TENSOR.V_SAM_MLP_LIN_1,
+        MODEL_TENSOR.V_SAM_MLP_LIN_2,
+        MODEL_TENSOR.V_SAM_NECK,
+        MODEL_TENSOR.V_SAM_NET_2,
+        MODEL_TENSOR.V_SAM_NET_3,
         # audio
         MODEL_TENSOR.A_ENC_EMBD_POS,
         MODEL_TENSOR.A_ENC_CONV1D,


@@ -1077,6 +1077,12 @@ class GGUFWriter:
     def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None:
         self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers)
 
+    def add_vision_sam_layers_count(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.SAM.BLOCK_COUNT, value)
+
+    def add_vision_sam_embedding_length(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.SAM.EMBEDDING_LENGTH, value)
+
     # audio models
 
     def add_audio_projection_dim(self, value: int) -> None:
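A hedged usage sketch of the two new writer methods, assuming the in-tree gguf-py package is installed with this patch applied; the output path, the "clip" arch string, and the SAM sizes are placeholders, not values from the commit:

import gguf

# Write the new SAM metadata keys into an mmproj-style GGUF file.
writer = gguf.GGUFWriter("mmproj-deepseek-ocr.gguf", "clip")
writer.add_vision_sam_layers_count(12)       # clip.vision.sam.block_count
writer.add_vision_sam_embedding_length(768)  # clip.vision.sam.embedding_length
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.write_tensors_to_file()
writer.close()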


@@ -1179,6 +1179,7 @@ class TensorNameMap:
         MODEL_TENSOR.V_MMPROJ_FC: (
             "model.connector.modality_projection.proj",  # SmolVLM
             "model.vision.linear_proj.linear_proj",  # cogvlm
+            "model.projector.layers",  # Deepseek-OCR
         ),
 
         MODEL_TENSOR.V_MMPROJ_MLP: (

@@ -1197,6 +1198,7 @@ class TensorNameMap:
             "model.vision_tower.embeddings.cls_token",  # Intern-S1
             "vision_model.class_embedding",  # llama 4
             "model.vision.patch_embedding.cls_embedding",  # cogvlm
+            "model.vision_model.embeddings.class_embedding",  # Deepseek-OCR
         ),
 
         MODEL_TENSOR.V_ENC_EMBD_PATCH: (

@@ -1210,6 +1212,7 @@ class TensorNameMap:
             "visual.patch_embed.proj",  # qwen2vl
             "vision_tower.patch_embed.proj",  # kimi-vl
             "model.vision.patch_embedding.proj",  # cogvlm
+            "model.vision_model.embeddings.patch_embedding",  # Deepseek-OCR CLIP
         ),
 
         MODEL_TENSOR.V_ENC_EMBD_POS: (

@@ -1223,9 +1226,18 @@ class TensorNameMap:
             "model.vision.patch_embedding.position_embedding",  # cogvlm
         ),
 
+        MODEL_TENSOR.V_ENC_EMBD_IMGNL: (
+            "model.image_newline",  # Deepseek-OCR
+        ),
+
+        MODEL_TENSOR.V_ENC_EMBD_VSEP: (
+            "model.view_seperator",  # Deepseek-OCR
+        ),
+
         MODEL_TENSOR.V_ENC_ATTN_QKV: (
             "visual.blocks.{bid}.attn.qkv",  # qwen3vl
             "model.vision.transformer.layers.{bid}.attention.query_key_value",  # cogvlm
+            "model.vision_model.transformer.layers.{bid}.self_attn.qkv_proj",  # Deepseek-OCR CLIP
         ),
 
         MODEL_TENSOR.V_ENC_ATTN_Q: (
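These per-architecture source names are what TensorNameMap uses to resolve checkpoint tensors onto the GGUF names registered earlier. A simplified standalone sketch of that lookup (not the real class, which also handles `.weight`/`.bias` suffixes), using only name pairs that appear in this commit:

# Simplified sketch: expand "{bid}" entries and resolve a source tensor name.
MAPPING = {
    "model.image_newline": "v.image_newline_embd",
    "model.view_seperator": "v.view_separator_embd",
    "model.sam_model.neck.{bid}": "v.sam.neck.{bid}",  # appears in a later hunk
}

def map_name(hf_name: str, n_blocks: int = 24) -> str | None:
    expanded: dict[str, str] = {}
    for src, dst in MAPPING.items():
        if "{bid}" in src:
            for bid in range(n_blocks):
                expanded[src.format(bid=bid)] = dst.format(bid=bid)
        else:
            expanded[src] = dst
    return expanded.get(hf_name)

print(map_name("model.sam_model.neck.1"))  # -> v.sam.neck.1
print(map_name("model.image_newline"))     # -> v.image_newline_embd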
@@ -1238,6 +1250,7 @@ class TensorNameMap:
             "vision_encoder.transformer.layers.{bid}.attention.wq",  # pixtral
             "visual.blocks.{bid}.attn.q",  # qwen2vl, generated
             "vision_tower.encoder.blocks.{bid}.wq",  # kimi-vl, generated
+            "model.vision_model.transformer.layers.{bid}.self_attn.q_proj",  # Deepseek-OCR CLIP, generated
         ),
 
         MODEL_TENSOR.V_ENC_ATTN_Q_NORM: (

@@ -1255,6 +1268,7 @@ class TensorNameMap:
             "vision_encoder.transformer.layers.{bid}.attention.wk",  # pixtral
             "visual.blocks.{bid}.attn.k",  # qwen2vl, generated
             "vision_tower.encoder.blocks.{bid}.wk",  # kimi-vl, generated
+            "model.vision_model.transformer.layers.{bid}.self_attn.k_proj",  # Deepseek-OCR CLIP, generated
         ),
 
         MODEL_TENSOR.V_ENC_ATTN_K_NORM: (

@@ -1272,6 +1286,7 @@ class TensorNameMap:
             "vision_encoder.transformer.layers.{bid}.attention.wv",  # pixtral
             "visual.blocks.{bid}.attn.v",  # qwen2vl, generated
             "vision_tower.encoder.blocks.{bid}.wv",  # kimi-vl, generated
+            "model.vision_model.transformer.layers.{bid}.self_attn.v_proj",  # Deepseek-OCR CLIP, generated
         ),
 
         MODEL_TENSOR.V_ENC_INPUT_NORM: (

@@ -1286,6 +1301,7 @@ class TensorNameMap:
             "visual.blocks.{bid}.norm1",  # qwen2vl
             "vision_tower.encoder.blocks.{bid}.norm0",  # kimi-vl (norm0/norm1)
             "model.vision.transformer.layers.{bid}.input_layernorm",  # cogvlm
+            "model.vision_model.transformer.layers.{bid}.layer_norm1",  # Deepseek-OCR CLIP
         ),
 
         MODEL_TENSOR.V_ENC_ATTN_O: (

@@ -1301,6 +1317,7 @@ class TensorNameMap:
             "visual.blocks.{bid}.attn.proj",  # qwen2vl
             "vision_tower.encoder.blocks.{bid}.wo",  # kimi-vl
             "model.vision.transformer.layers.{bid}.attention.dense",  # cogvlm
+            "model.vision_model.transformer.layers.{bid}.self_attn.out_proj",  # Deepseek-OCR CLIP
         ),
 
         MODEL_TENSOR.V_ENC_POST_ATTN_NORM: (

@@ -1315,6 +1332,7 @@ class TensorNameMap:
             "visual.blocks.{bid}.norm2",  # qwen2vl
             "vision_tower.encoder.blocks.{bid}.norm1",  # kimi-vl (norm0/norm1)
             "model.vision.transformer.layers.{bid}.post_attention_layernorm",  # cogvlm
+            "model.vision_model.transformer.layers.{bid}.layer_norm2",  # Deepseek-OCR CLIP
         ),
 
         MODEL_TENSOR.V_ENC_FFN_UP: (

@@ -1329,6 +1347,7 @@ class TensorNameMap:
             "visual.blocks.{bid}.mlp.up_proj",  # qwen2.5vl
             "visual.blocks.{bid}.mlp.linear_fc1",  # qwen3vl
             "vision_tower.encoder.blocks.{bid}.mlp.fc0",  # kimi-vl (fc0/fc1)
+            "model.vision_model.transformer.layers.{bid}.mlp.fc1",  # Deepseek-OCR CLIP
             "model.vision.transformer.layers.{bid}.mlp.fc1",  # cogvlm
         ),

@@ -1351,6 +1370,7 @@ class TensorNameMap:
             "visual.blocks.{bid}.mlp.linear_fc2",  # qwen3vl
             "vision_tower.encoder.blocks.{bid}.mlp.fc1",  # kimi-vl (fc0/fc1)
             "model.vision.transformer.layers.{bid}.mlp.fc2",  # cogvlm
+            "model.vision_model.transformer.layers.{bid}.mlp.fc2",  # Deepseek-OCR CLIP
         ),
 
         MODEL_TENSOR.V_LAYER_SCALE_1: (

@@ -1368,6 +1388,7 @@ class TensorNameMap:
             "vision_tower.ln_pre",  # pixtral-hf
             "vision_encoder.ln_pre",  # pixtral
             "vision_model.layernorm_pre",  # llama4
+            "model.vision_model.pre_layrnorm",  # Deepseek-OCR CLIP
         ),
 
         MODEL_TENSOR.V_POST_NORM: (
@@ -1460,11 +1481,11 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.V_SAM_POS_EMBD: (
-            "model.sam_model.pos_embed"
+            "model.sam_model.pos_embed",
         ),
 
         MODEL_TENSOR.V_SAM_PATCH_EMBD: (
-            "model.sam_model.patch_embed.proj"
+            "model.sam_model.patch_embed.proj",
         ),
 
         MODEL_TENSOR.V_SAM_PRE_NORM: (

@@ -1476,19 +1497,19 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.V_SAM_ATTN_POS_H: (
-            "model.sam_model.blocks.{bid}.attn.rel_pos_h"
+            "model.sam_model.blocks.{bid}.attn.rel_pos_h",
         ),
 
         MODEL_TENSOR.V_SAM_ATTN_POS_W: (
-            "model.sam_model.blocks.{bid}.attn.rel_pos_w"
+            "model.sam_model.blocks.{bid}.attn.rel_pos_w",
         ),
 
         MODEL_TENSOR.V_SAM_ATTN_QKV: (
-            "model.sam_model.blocks.{bid}.attn.qkv"
+            "model.sam_model.blocks.{bid}.attn.qkv",
         ),
 
         MODEL_TENSOR.V_SAM_ATTN_OUT: (
-            "model.sam_model.blocks.{bid}.attn.proj"
+            "model.sam_model.blocks.{bid}.attn.proj",
         ),
 
         MODEL_TENSOR.V_SAM_MLP_LIN_1: (

@@ -1500,15 +1521,15 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.V_SAM_NECK: (
-            "model.sam_model.neck.{bid}"
+            "model.sam_model.neck.{bid}",
         ),
 
         MODEL_TENSOR.V_SAM_NET_2: (
-            "model.sam_model.net_2"
+            "model.sam_model.net_2",
         ),
 
         MODEL_TENSOR.V_SAM_NET_3: (
-            "model.sam_model.net_3"
+            "model.sam_model.net_3",
         ),
 
         MODEL_TENSOR.V_MM_POST_FC_NORM: (
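The last three hunks only add trailing commas, but they matter: in Python a parenthesized single string without a comma is just a string, so iterating the "tuple" of source names would walk its characters instead of yielding the name. A minimal demonstration (not part of the commit):

# Without the trailing comma the parentheses do nothing: this is a plain str.
not_a_tuple = ("model.sam_model.pos_embed")
real_tuple = ("model.sam_model.pos_embed",)

print(type(not_a_tuple).__name__)  # -> str
print(type(real_tuple).__name__)   # -> tuple
print(list(not_a_tuple)[:3])       # -> ['m', 'o', 'd']  (iterates characters)
print(list(real_tuple))            # -> ['model.sam_model.pos_embed']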