h9-tec 2025-12-16 16:38:07 +08:00 committed by GitHub
commit 755aeef41c
14 changed files with 2754 additions and 0 deletions

convert_nougat_to_gguf.py Normal file

@ -0,0 +1,386 @@
#!/usr/bin/env python3
"""
Convert Nougat (Neural Optical Understanding for Academic Documents) model to GGUF format.
This script handles the conversion of Nougat's Swin Transformer encoder and mBART decoder.
"""
import argparse
import json
import os
import struct
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
from transformers import NougatProcessor, VisionEncoderDecoderModel
# Add the bundled gguf-py package (located next to this script) to the import path
sys.path.append(str(Path(__file__).parent / "gguf-py"))
import gguf
# Constants for Nougat
NOUGAT_VISION_PREFIX = "vision_model"
NOUGAT_DECODER_PREFIX = "decoder"
NOUGAT_ENCODER_PREFIX = "encoder"
def parse_args():
parser = argparse.ArgumentParser(description="Convert Nougat model to GGUF format")
parser.add_argument(
"--model-id",
type=str,
default="facebook/nougat-base",
help="HuggingFace model ID or path to local model",
)
parser.add_argument(
"--output-dir",
type=str,
default="./models",
help="Output directory for GGUF files",
)
parser.add_argument(
"--quantization",
type=str,
choices=["f32", "f16", "q8_0", "q4_0", "q4_1"],
default="f16",
help="Quantization type for model weights",
)
parser.add_argument(
"--split-model",
action="store_true",
help="Split into separate vision and text GGUF files",
)
parser.add_argument(
"--vocab-only",
action="store_true",
help="Only export vocabulary/tokenizer",
)
parser.add_argument(
"--verbose",
action="store_true",
help="Verbose output during conversion",
)
return parser.parse_args()
def get_tensor_name(name: str) -> str:
"""Map Nougat tensor names to GGUF tensor names"""
# Vision model (Swin Transformer) mappings
if name.startswith("encoder.model.encoder."):
# Swin encoder layers
name = name.replace("encoder.model.encoder.", "swin.")
# Patch embedding
if "embeddings.patch_embeddings" in name:
if "projection.weight" in name:
return "swin.patch_embed.weight"
elif "projection.bias" in name:
return "swin.patch_embed.bias"
# Position embeddings
if "position_embeddings" in name:
return "swin.pos_embed"
# Layer mappings
if "layers." in name:
# Extract stage and layer indices
parts = name.split(".")
for i, part in enumerate(parts):
if part == "layers":
stage_idx = int(parts[i + 1])
if "blocks." in name:
block_idx = int(parts[parts.index("blocks") + 1])
# Attention components
if "attn.qkv" in name:
return f"swin.stage.{stage_idx}.layer.{block_idx}.attn.qkv.{'weight' if 'weight' in name else 'bias'}"
elif "attn.proj" in name:
return f"swin.stage.{stage_idx}.layer.{block_idx}.attn.proj.{'weight' if 'weight' in name else 'bias'}"
elif "norm1" in name:
return f"swin.stage.{stage_idx}.layer.{block_idx}.norm1.{'weight' if 'weight' in name else 'bias'}"
elif "norm2" in name:
return f"swin.stage.{stage_idx}.layer.{block_idx}.norm2.{'weight' if 'weight' in name else 'bias'}"
elif "mlp.fc1" in name:
return f"swin.stage.{stage_idx}.layer.{block_idx}.mlp.fc1.{'weight' if 'weight' in name else 'bias'}"
elif "mlp.fc2" in name:
return f"swin.stage.{stage_idx}.layer.{block_idx}.mlp.fc2.{'weight' if 'weight' in name else 'bias'}"
# Downsample layers
elif "downsample" in name:
if "norm" in name:
return f"swin.stage.{stage_idx}.downsample.norm.{'weight' if 'weight' in name else 'bias'}"
elif "reduction" in name:
return f"swin.stage.{stage_idx}.downsample.reduction.weight"
# Decoder model (mBART) mappings
elif name.startswith("decoder.model."):
name = name.replace("decoder.model.", "")
# Token and position embeddings
if name == "shared.weight":
return "token_embd.weight"
elif name == "decoder.embed_positions.weight":
return "position_embd.weight"
# Decoder layers
if "decoder.layers." in name:
layer_idx = int(name.split(".")[2])
# Self-attention
if "self_attn.q_proj" in name:
return f"blk.{layer_idx}.attn_q.weight"
elif "self_attn.k_proj" in name:
return f"blk.{layer_idx}.attn_k.weight"
elif "self_attn.v_proj" in name:
return f"blk.{layer_idx}.attn_v.weight"
elif "self_attn.out_proj" in name:
return f"blk.{layer_idx}.attn_o.weight"
elif "self_attn_layer_norm" in name:
return f"blk.{layer_idx}.attn_norm.{'weight' if 'weight' in name else 'bias'}"
# Cross-attention
elif "encoder_attn.q_proj" in name:
return f"blk.{layer_idx}.attn_q_cross.weight"
elif "encoder_attn.k_proj" in name:
return f"blk.{layer_idx}.attn_k_cross.weight"
elif "encoder_attn.v_proj" in name:
return f"blk.{layer_idx}.attn_v_cross.weight"
elif "encoder_attn.out_proj" in name:
return f"blk.{layer_idx}.attn_o_cross.weight"
elif "encoder_attn_layer_norm" in name:
return f"blk.{layer_idx}.attn_norm_cross.{'weight' if 'weight' in name else 'bias'}"
# FFN
elif "fc1" in name:
return f"blk.{layer_idx}.ffn_up.weight"
elif "fc2" in name:
return f"blk.{layer_idx}.ffn_down.weight"
elif "final_layer_norm" in name:
return f"blk.{layer_idx}.ffn_norm.{'weight' if 'weight' in name else 'bias'}"
# Output layers
elif "decoder.layer_norm" in name:
return f"output_norm.{'weight' if 'weight' in name else 'bias'}"
elif "lm_head" in name:
return "output.weight"
# Encoder layers (for encoder-only export)
elif name.startswith("encoder."):
name = name.replace("encoder.", "enc.")
# Similar mappings but with enc. prefix
return f"enc.{name}"
# Default: return original name
return name
def convert_swin_encoder(model_dict: Dict[str, torch.Tensor], gguf_writer: gguf.GGUFWriter, args):
"""Convert Swin Transformer encoder weights to GGUF format"""
print("Converting Swin Transformer encoder...")
# Write Swin hyperparameters
swin_config = {
"window_size": 7,
"patch_size": 4,
"image_size": 384, # Default for Nougat
"hidden_dim": 96,
"depths": [2, 2, 6, 2],
"num_heads": [3, 6, 12, 24],
"mlp_ratio": 4.0,
"norm_eps": 1e-5,
}
gguf_writer.add_string("swin.type", "swin_transformer")
gguf_writer.add_int32("swin.window_size", swin_config["window_size"])
gguf_writer.add_int32("swin.patch_size", swin_config["patch_size"])
gguf_writer.add_int32("swin.image_size", swin_config["image_size"])
gguf_writer.add_int32("swin.hidden_dim", swin_config["hidden_dim"])
gguf_writer.add_float32("swin.mlp_ratio", swin_config["mlp_ratio"])
gguf_writer.add_float32("swin.norm_eps", swin_config["norm_eps"])
# Convert encoder weights
encoder_tensors = {k: v for k, v in model_dict.items() if k.startswith("encoder.")}
for name, tensor in encoder_tensors.items():
gguf_name = get_tensor_name(name)
if args.verbose:
print(f" {name} -> {gguf_name} {list(tensor.shape)}")
# Convert to appropriate dtype
if args.quantization == "f32":
data = tensor.float().cpu().numpy()
elif args.quantization == "f16":
data = tensor.half().cpu().numpy()
else:
# Quantization would be applied here
data = tensor.float().cpu().numpy()
gguf_writer.add_tensor(gguf_name, data)
print(f" Converted {len(encoder_tensors)} encoder tensors")
def convert_mbart_decoder(model_dict: Dict[str, torch.Tensor], gguf_writer: gguf.GGUFWriter, args):
"""Convert mBART decoder weights to GGUF format"""
print("Converting mBART decoder...")
# Write mBART architecture info
gguf_writer.add_string("general.architecture", "mbart")
# Convert decoder weights
decoder_tensors = {k: v for k, v in model_dict.items() if k.startswith("decoder.")}
for name, tensor in decoder_tensors.items():
gguf_name = get_tensor_name(name)
if args.verbose:
print(f" {name} -> {gguf_name} {list(tensor.shape)}")
# Convert to appropriate dtype
if args.quantization == "f32":
data = tensor.float().cpu().numpy()
elif args.quantization == "f16":
data = tensor.half().cpu().numpy()
else:
# Quantization would be applied here
data = tensor.float().cpu().numpy()
gguf_writer.add_tensor(gguf_name, data)
print(f" Converted {len(decoder_tensors)} decoder tensors")
def convert_tokenizer(processor, gguf_writer: gguf.GGUFWriter, args):
"""Convert Nougat tokenizer/processor to GGUF format"""
print("Converting tokenizer...")
tokenizer = processor.tokenizer
vocab = tokenizer.get_vocab()
# Write tokenizer metadata
gguf_writer.add_string("tokenizer.model", "mbart")
gguf_writer.add_int32("tokenizer.vocab_size", len(vocab))
# Add special tokens
special_tokens = {
"bos": tokenizer.bos_token,
"eos": tokenizer.eos_token,
"unk": tokenizer.unk_token,
"pad": tokenizer.pad_token,
}
for key, token in special_tokens.items():
if token:
gguf_writer.add_string(f"tokenizer.{key}_token", token)
gguf_writer.add_int32(f"tokenizer.{key}_token_id", tokenizer.convert_tokens_to_ids(token))
# Add vocabulary
tokens = []
scores = []
token_types = []
for token, token_id in sorted(vocab.items(), key=lambda x: x[1]):
tokens.append(token.encode("utf-8"))
scores.append(0.0) # Dummy scores for now
token_types.append(1 if token in tokenizer.all_special_tokens else 0)
gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(token_types)
print(f" Vocabulary size: {len(vocab)}")
def main():
args = parse_args()
# Create output directory
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
print(f"Loading Nougat model from {args.model_id}...")
# Load model and processor
processor = NougatProcessor.from_pretrained(args.model_id)
model = VisionEncoderDecoderModel.from_pretrained(args.model_id)
# Get model state dict
state_dict = model.state_dict()
if args.split_model:
# Create separate files for vision and text models
# Vision model (Swin encoder)
vision_output = output_dir / "nougat-vision.gguf"
print(f"\nCreating vision model: {vision_output}")
vision_writer = gguf.GGUFWriter(str(vision_output), "nougat-vision")
vision_writer.add_string("general.name", "Nougat Vision Model (Swin)")
vision_writer.add_string("general.description", "Swin Transformer encoder for Nougat OCR")
vision_writer.add_string("general.architecture", "swin")
convert_swin_encoder(state_dict, vision_writer, args)
vision_writer.write_header_to_file()
vision_writer.write_kv_data_to_file()
vision_writer.write_tensors_to_file()
vision_writer.close()
# Text model (mBART decoder)
text_output = output_dir / "nougat-text.gguf"
print(f"\nCreating text model: {text_output}")
text_writer = gguf.GGUFWriter(str(text_output), "nougat-text")
text_writer.add_string("general.name", "Nougat Text Model (mBART)")
text_writer.add_string("general.description", "mBART decoder for Nougat OCR")
convert_mbart_decoder(state_dict, text_writer, args)
convert_tokenizer(processor, text_writer, args)
text_writer.write_header_to_file()
text_writer.write_kv_data_to_file()
text_writer.write_tensors_to_file()
text_writer.close()
else:
# Create single combined model file
output_file = output_dir / "nougat-combined.gguf"
print(f"\nCreating combined model: {output_file}")
writer = gguf.GGUFWriter(str(output_file), "nougat")
writer.add_string("general.name", "Nougat OCR Model")
writer.add_string("general.description", "Neural Optical Understanding for Academic Documents")
writer.add_string("general.architecture", "nougat")
# Add encoder and decoder weights (skipped when only the vocabulary/tokenizer is requested)
if not args.vocab_only:
    convert_swin_encoder(state_dict, writer, args)
    convert_mbart_decoder(state_dict, writer, args)
convert_tokenizer(processor, writer, args)
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.write_tensors_to_file()
writer.close()
print("\nConversion complete!")
# Print model statistics
total_params = sum(p.numel() for p in model.parameters())
print(f"\nModel statistics:")
print(f" Total parameters: {total_params:,}")
print(f" Encoder parameters: {sum(p.numel() for n, p in model.named_parameters() if 'encoder' in n):,}")
print(f" Decoder parameters: {sum(p.numel() for n, p in model.named_parameters() if 'decoder' in n):,}")
if args.quantization != "f32":
print(f" Quantization: {args.quantization}")
if __name__ == "__main__":
main()
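As a quick sanity check of the name-mapping rules above, here is a small hypothetical snippet (it assumes the script is importable as a module named convert_nougat_to_gguf, and that its top-level imports such as torch and transformers are installed); the expected GGUF names follow directly from the mapping rules defined above:

# Hypothetical sanity check for get_tensor_name().
from convert_nougat_to_gguf import get_tensor_name

expected = {
    "encoder.model.encoder.embeddings.patch_embeddings.projection.weight": "swin.patch_embed.weight",
    "decoder.model.decoder.layers.0.self_attn.q_proj.weight": "blk.0.attn_q.weight",
    "decoder.model.decoder.layers.3.encoder_attn.k_proj.weight": "blk.3.attn_k_cross.weight",
    "decoder.model.decoder.layers.3.fc1.weight": "blk.3.ffn_up.weight",
}
for hf_name, gguf_name in expected.items():
    assert get_tensor_name(hf_name) == gguf_name, (hf_name, get_tensor_name(hf_name))
print("tensor-name mapping looks consistent")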


@ -410,6 +410,8 @@ class MODEL_ARCH(IntEnum):
BITNET = auto()
T5 = auto()
T5ENCODER = auto()
MBART = auto()
MBARTENCODER = auto()
JAIS = auto()
NEMOTRON = auto()
NEMOTRON_H = auto()
@ -784,6 +786,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.BITNET: "bitnet",
MODEL_ARCH.T5: "t5",
MODEL_ARCH.T5ENCODER: "t5encoder",
MODEL_ARCH.MBART: "mbart",
MODEL_ARCH.MBARTENCODER: "mbartencoder",
MODEL_ARCH.JAIS: "jais",
MODEL_ARCH.NEMOTRON: "nemotron",
MODEL_ARCH.NEMOTRON_H: "nemotron_h",
@ -2485,6 +2489,48 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.ENC_FFN_UP,
MODEL_TENSOR.ENC_OUTPUT_NORM,
],
MODEL_ARCH.MBART: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.POS_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_NORM_CROSS,
MODEL_TENSOR.ATTN_Q_CROSS,
MODEL_TENSOR.ATTN_K_CROSS,
MODEL_TENSOR.ATTN_V_CROSS,
MODEL_TENSOR.ATTN_OUT_CROSS,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.ENC_ATTN_NORM,
MODEL_TENSOR.ENC_ATTN_Q,
MODEL_TENSOR.ENC_ATTN_K,
MODEL_TENSOR.ENC_ATTN_V,
MODEL_TENSOR.ENC_ATTN_OUT,
MODEL_TENSOR.ENC_FFN_NORM,
MODEL_TENSOR.ENC_FFN_DOWN,
MODEL_TENSOR.ENC_FFN_UP,
MODEL_TENSOR.ENC_OUTPUT_NORM,
],
MODEL_ARCH.MBARTENCODER: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.POS_EMBD,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ENC_ATTN_NORM,
MODEL_TENSOR.ENC_ATTN_Q,
MODEL_TENSOR.ENC_ATTN_K,
MODEL_TENSOR.ENC_ATTN_V,
MODEL_TENSOR.ENC_ATTN_OUT,
MODEL_TENSOR.ENC_FFN_NORM,
MODEL_TENSOR.ENC_FFN_DOWN,
MODEL_TENSOR.ENC_FFN_UP,
MODEL_TENSOR.ENC_OUTPUT_NORM,
],
MODEL_ARCH.JAIS: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,


@ -72,6 +72,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_BITNET, "bitnet" },
{ LLM_ARCH_T5, "t5" },
{ LLM_ARCH_T5ENCODER, "t5encoder" },
{ LLM_ARCH_MBART, "mbart" },
{ LLM_ARCH_MBARTENCODER, "mbartencoder" },
{ LLM_ARCH_JAIS, "jais" },
{ LLM_ARCH_NEMOTRON, "nemotron" },
{ LLM_ARCH_NEMOTRON_H, "nemotron_h" },
@ -1706,6 +1708,54 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" },
},
},
{
LLM_ARCH_MBART,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_POS_EMBD, "position_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_o" },
{ LLM_TENSOR_ATTN_NORM_CROSS, "blk.%d.attn_norm_cross" },
{ LLM_TENSOR_ATTN_Q_CROSS, "blk.%d.attn_q_cross" },
{ LLM_TENSOR_ATTN_K_CROSS, "blk.%d.attn_k_cross" },
{ LLM_TENSOR_ATTN_V_CROSS, "blk.%d.attn_v_cross" },
{ LLM_TENSOR_ATTN_OUT_CROSS, "blk.%d.attn_o_cross" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_ENC_ATTN_NORM, "enc.blk.%d.attn_norm" },
{ LLM_TENSOR_ENC_ATTN_Q, "enc.blk.%d.attn_q" },
{ LLM_TENSOR_ENC_ATTN_K, "enc.blk.%d.attn_k" },
{ LLM_TENSOR_ENC_ATTN_V, "enc.blk.%d.attn_v" },
{ LLM_TENSOR_ENC_ATTN_OUT, "enc.blk.%d.attn_o" },
{ LLM_TENSOR_ENC_FFN_NORM, "enc.blk.%d.ffn_norm" },
{ LLM_TENSOR_ENC_FFN_DOWN, "enc.blk.%d.ffn_down" },
{ LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" },
},
},
{
LLM_ARCH_MBARTENCODER,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_POS_EMBD, "position_embd" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" },
{ LLM_TENSOR_ENC_ATTN_NORM, "enc.blk.%d.attn_norm" },
{ LLM_TENSOR_ENC_ATTN_Q, "enc.blk.%d.attn_q" },
{ LLM_TENSOR_ENC_ATTN_K, "enc.blk.%d.attn_k" },
{ LLM_TENSOR_ENC_ATTN_V, "enc.blk.%d.attn_v" },
{ LLM_TENSOR_ENC_ATTN_OUT, "enc.blk.%d.attn_o" },
{ LLM_TENSOR_ENC_FFN_NORM, "enc.blk.%d.ffn_norm" },
{ LLM_TENSOR_ENC_FFN_DOWN, "enc.blk.%d.ffn_down" },
{ LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" },
},
},
{
LLM_ARCH_JAIS,
{


@ -76,6 +76,8 @@ enum llm_arch {
LLM_ARCH_BITNET,
LLM_ARCH_T5,
LLM_ARCH_T5ENCODER,
LLM_ARCH_MBART,
LLM_ARCH_MBARTENCODER,
LLM_ARCH_JAIS,
LLM_ARCH_NEMOTRON,
LLM_ARCH_NEMOTRON_H,
@ -326,6 +328,11 @@ enum llm_tensor {
LLM_TENSOR_ATTN_OUT,
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_NORM_2,
LLM_TENSOR_ATTN_NORM_CROSS,
LLM_TENSOR_ATTN_Q_CROSS,
LLM_TENSOR_ATTN_K_CROSS,
LLM_TENSOR_ATTN_V_CROSS,
LLM_TENSOR_ATTN_OUT_CROSS,
LLM_TENSOR_ATTN_OUT_NORM,
LLM_TENSOR_ATTN_POST_NORM,
LLM_TENSOR_ATTN_ROT_EMBD,

src/models/mbart-dec.cpp Normal file

@ -0,0 +1,162 @@
#include "models.h"
llm_build_mbart_dec::llm_build_mbart_dec(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
ggml_tensor * cur;
ggml_tensor * inpL;
// mBART uses learned positional embeddings
inpL = build_inp_embd(model.tok_embd);
// Add positional embeddings
ggml_tensor * pos_embd = build_inp_pos_embd();
if (pos_embd) {
inpL = ggml_add(ctx0, inpL, pos_embd);
cb(inpL, "pos_embd", -1);
}
// Get encoder embeddings for cross-attention
ggml_tensor * embd_enc = build_inp_cross_embd();
const int64_t n_outputs_enc = embd_enc->ne[1];
// Layer normalization before the first layer (mBART characteristic)
cur = build_norm(inpL,
model.output_norm, NULL,
LLM_NORM, -1);
cb(cur, "input_norm", -1);
inpL = cur;
auto * inp_attn_self = build_attn_inp_kv();
auto * inp_attn_cross = build_attn_inp_cross();
ggml_tensor * inp_out_ids = build_inp_out_ids();
const int64_t dec_n_layer = hparams.dec_n_layer;
for (int il = 0; il < dec_n_layer; ++il) {
ggml_tensor * inpSA = inpL;
// self-attention
{
// norm before attention
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
LLM_NORM, il);
cb(cur, "attn_norm", il);
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
// mBART uses standard scaled dot-product attention
cur = build_attn(inp_attn_self,
model.layers[il].wo, model.layers[il].bo,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf((float)n_embd_head), il);
cb(cur, "kqv_out", il);
}
// residual connection
cur = ggml_add(ctx0, cur, inpSA);
cb(cur, "self_attn_out", il);
ggml_tensor * inpCA = cur;
// cross-attention
{
// norm before cross-attention
cur = build_norm(cur,
model.layers[il].attn_norm_cross, NULL,
LLM_NORM, il);
cb(cur, "attn_norm_cross", il);
// Q from decoder, K and V from encoder
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq_cross, cur);
cb(Qcur, "Qcur_cross", il);
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk_cross, embd_enc);
cb(Kcur, "Kcur_cross", il);
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv_cross, embd_enc);
cb(Vcur, "Vcur_cross", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_outputs_enc);
cur = build_attn(inp_attn_cross,
model.layers[il].wo_cross, model.layers[il].bo_cross,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf((float)n_embd_head), il);
cb(cur, "kqv_cross_out", il);
}
if (il == dec_n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids);
}
// residual connection
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpCA);
cb(ffn_inp, "cross_attn_out", il);
// feed-forward network
{
// norm before FFN
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM, il);
cb(cur, "ffn_norm", il);
// mBART uses GELU activation
cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
NULL, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_GELU,
LLM_FFN_SEQ,
il);
cb(cur, "ffn_out", il);
}
// residual connection
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "layer_out", il);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cb(cur, "result_embd", -1);
// Final layer normalization
cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
// lm_head for generation
cur = build_lora_mm(model.output, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}

src/models/mbart-enc.cpp Normal file

@ -0,0 +1,114 @@
#include "models.h"
llm_build_mbart_enc::llm_build_mbart_enc(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
ggml_tensor * cur;
ggml_tensor * inpL;
// mBART uses learned positional embeddings
inpL = build_inp_embd(model.tok_embd);
// Add positional embeddings for mBART
ggml_tensor * pos_embd = build_inp_pos_embd();
if (pos_embd) {
inpL = ggml_add(ctx0, inpL, pos_embd);
cb(inpL, "pos_embd", -1);
}
// Layer normalization before the first layer (mBART characteristic)
cur = build_norm(inpL,
model.output_norm_enc, NULL,
LLM_NORM, -1);
cb(cur, "input_norm", -1);
inpL = cur;
auto * inp_attn = build_attn_inp_no_cache();
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
// self-attention (mBART uses pre-norm)
{
// norm before attention
cur = build_norm(inpL,
model.layers[il].attn_norm_enc, NULL,
LLM_NORM, il);
cb(cur, "attn_norm", il);
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq_enc, cur);
cb(Qcur, "Qcur", il);
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk_enc, cur);
cb(Kcur, "Kcur", il);
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv_enc, cur);
cb(Vcur, "Vcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
// mBART uses standard scaled dot-product attention without relative position bias
cur = build_attn(inp_attn,
model.layers[il].wo_enc, model.layers[il].bo_enc,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf((float)n_embd_head), il);
cb(cur, "kqv_out", il);
}
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
// residual connection
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "attn_out", il);
// feed-forward network
{
// norm before FFN
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm_enc, NULL,
LLM_NORM, il);
cb(cur, "ffn_norm", il);
// mBART uses GELU activation
cur = build_ffn(cur,
model.layers[il].ffn_up_enc, NULL, NULL,
NULL, NULL, NULL,
model.layers[il].ffn_down_enc, NULL, NULL,
NULL,
LLM_FFN_GELU,
LLM_FFN_SEQ,
il);
cb(cur, "ffn_out", il);
}
// residual connection
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cb(cur, "result_embd", -1);
// Final layer normalization
cur = build_norm(cur,
model.output_norm_enc, NULL,
LLM_NORM, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
ggml_build_forward_expand(gf, cur);
}


@ -535,6 +535,14 @@ struct llm_build_t5_enc : public llm_graph_context {
llm_build_t5_enc(const llama_model & model, const llm_graph_params & params);
};
struct llm_build_mbart_enc : public llm_graph_context {
llm_build_mbart_enc(const llama_model & model, const llm_graph_params & params);
};
struct llm_build_mbart_dec : public llm_graph_context {
llm_build_mbart_dec(const llama_model & model, const llm_graph_params & params);
};
struct llm_build_wavtokenizer_dec : public llm_graph_context {
llm_build_wavtokenizer_dec(const llama_model & model, const llm_graph_params & params);
};


@ -29,6 +29,7 @@ else()
add_subdirectory(tokenize)
add_subdirectory(tts)
add_subdirectory(mtmd)
add_subdirectory(nougat)
if (GGML_RPC)
add_subdirectory(rpc)
endif()


@ -0,0 +1,386 @@
#include "clip.h"
#include "swin.h"
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <vector>
#include <string>
#include <memory>
#include <algorithm>
// External image loading library integration (stb_image)
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"
#ifdef _WIN32
#include <windows.h>
#else
#include <sys/stat.h>
#endif
// Document-specific preprocessing parameters for Nougat
struct nougat_preprocess_params {
int target_width = 896; // Nougat uses different resolution than standard vision models
int target_height = 1344; // Optimized for document aspect ratio
float mean[3] = {0.485f, 0.456f, 0.406f}; // ImageNet normalization
float std[3] = {0.229f, 0.224f, 0.225f};
bool center_crop = false; // Documents should not be center-cropped
bool maintain_aspect = true; // Important for documents
int patch_size = 4; // Swin Transformer patch size
};
// Structure to hold document metadata
struct document_metadata {
int original_width;
int original_height;
int num_pages;
std::string format; // PDF, PNG, JPG, etc.
float dpi;
};
// Preprocess a single document image for Nougat
static bool preprocess_document_image(
const uint8_t* img_data,
int width,
int height,
int channels,
const nougat_preprocess_params& params,
std::vector<float>& output) {
// Calculate scaling to fit target dimensions while maintaining aspect ratio
float scale_w = static_cast<float>(params.target_width) / width;
float scale_h = static_cast<float>(params.target_height) / height;
float scale = params.maintain_aspect ? std::min(scale_w, scale_h) : 1.0f;
int new_width = static_cast<int>(width * scale);
int new_height = static_cast<int>(height * scale);
// Ensure dimensions are divisible by patch size
new_width = (new_width / params.patch_size) * params.patch_size;
new_height = (new_height / params.patch_size) * params.patch_size;
// Resize image using bilinear interpolation
std::vector<uint8_t> resized_img(new_width * new_height * 3);
for (int y = 0; y < new_height; y++) {
for (int x = 0; x < new_width; x++) {
float src_x = x / scale;
float src_y = y / scale;
int x0 = static_cast<int>(src_x);
int y0 = static_cast<int>(src_y);
int x1 = std::min(x0 + 1, width - 1);
int y1 = std::min(y0 + 1, height - 1);
float fx = src_x - x0;
float fy = src_y - y0;
for (int c = 0; c < 3; c++) {
float v00 = img_data[(y0 * width + x0) * channels + c];
float v10 = img_data[(y0 * width + x1) * channels + c];
float v01 = img_data[(y1 * width + x0) * channels + c];
float v11 = img_data[(y1 * width + x1) * channels + c];
float v0 = v00 * (1 - fx) + v10 * fx;
float v1 = v01 * (1 - fx) + v11 * fx;
float v = v0 * (1 - fy) + v1 * fy;
resized_img[(y * new_width + x) * 3 + c] = static_cast<uint8_t>(v);
}
}
}
// Pad to target size if needed
int pad_left = (params.target_width - new_width) / 2;
int pad_top = (params.target_height - new_height) / 2;
output.resize(params.target_width * params.target_height * 3);
// Initialize with padding (white background for documents)
std::fill(output.begin(), output.end(), 1.0f);
// Copy resized image to output with normalization
for (int y = 0; y < new_height; y++) {
for (int x = 0; x < new_width; x++) {
int out_x = x + pad_left;
int out_y = y + pad_top;
if (out_x >= 0 && out_x < params.target_width &&
out_y >= 0 && out_y < params.target_height) {
for (int c = 0; c < 3; c++) {
float pixel = resized_img[(y * new_width + x) * 3 + c] / 255.0f;
pixel = (pixel - params.mean[c]) / params.std[c];
output[(out_y * params.target_width + out_x) * 3 + c] = pixel;
}
}
}
}
return true;
}
// Load and preprocess a document file (supports various formats)
bool nougat_preprocess_document_file(
const std::string& filename,
nougat_preprocess_params& params,
std::vector<std::vector<float>>& page_outputs,
document_metadata& metadata) {
// Check file extension
std::string ext = filename.substr(filename.find_last_of(".") + 1);
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
metadata.format = ext;
if (ext == "pdf") {
// PDF processing would require a PDF library like poppler or mupdf
// For now, we'll return an error for PDF files
fprintf(stderr, "PDF processing not yet implemented. Please convert to image format.\n");
return false;
}
// Load image using stb_image
int width, height, channels;
unsigned char* img_data = stbi_load(filename.c_str(), &width, &height, &channels, 3);
if (!img_data) {
fprintf(stderr, "Failed to load image: %s\n", filename.c_str());
return false;
}
metadata.original_width = width;
metadata.original_height = height;
metadata.num_pages = 1; // Single image
metadata.dpi = 300.0f; // Assume standard document DPI
// Preprocess the image
std::vector<float> output;
bool success = preprocess_document_image(
img_data, width, height, 3, params, output);
if (success) {
page_outputs.push_back(output);
}
stbi_image_free(img_data);
return success;
}
// Batch preprocessing for multiple document pages
bool nougat_preprocess_document_batch(
const std::vector<std::string>& filenames,
nougat_preprocess_params& params,
std::vector<std::vector<float>>& outputs) {
outputs.clear();
outputs.reserve(filenames.size());
for (const auto& filename : filenames) {
document_metadata metadata;
std::vector<std::vector<float>> page_outputs;
if (!nougat_preprocess_document_file(filename, params, page_outputs, metadata)) {
fprintf(stderr, "Failed to preprocess: %s\n", filename.c_str());
continue;
}
// Add all pages from this document
for (auto& page : page_outputs) {
outputs.push_back(std::move(page));
}
}
return !outputs.empty();
}
// Apply document-specific augmentations
void nougat_augment_document(
std::vector<float>& image_data,
int width,
int height,
bool random_rotation = false,
bool deskew = true,
bool denoise = true) {
// Document deskewing (straighten tilted scans)
if (deskew) {
// Simplified deskew - would need proper implementation
// using Hough transform or similar technique
}
// Denoising for scanned documents
if (denoise) {
// Apply median filter or similar denoising
// Simplified implementation
std::vector<float> temp = image_data;
for (int y = 1; y < height - 1; y++) {
for (int x = 1; x < width - 1; x++) {
for (int c = 0; c < 3; c++) {
std::vector<float> neighborhood;
// Collect 3x3 neighborhood
for (int dy = -1; dy <= 1; dy++) {
for (int dx = -1; dx <= 1; dx++) {
int idx = ((y + dy) * width + (x + dx)) * 3 + c;
neighborhood.push_back(temp[idx]);
}
}
// Median filter
std::sort(neighborhood.begin(), neighborhood.end());
image_data[(y * width + x) * 3 + c] = neighborhood[4];
}
}
}
}
// Random rotation for augmentation during training
if (random_rotation) {
// Apply small random rotation (-5 to +5 degrees)
// Would need proper rotation implementation
}
}
// Extract text regions from document for focused processing
struct text_region {
int x, y, width, height;
float confidence;
};
std::vector<text_region> nougat_detect_text_regions(
const std::vector<float>& image_data,
int width,
int height) {
std::vector<text_region> regions;
// Simple text detection based on connected components
// This would need a proper implementation using:
// - Edge detection
// - Connected component analysis
// - Text/non-text classification
// For now, return the whole image as a single region
text_region full_page;
full_page.x = 0;
full_page.y = 0;
full_page.width = width;
full_page.height = height;
full_page.confidence = 1.0f;
regions.push_back(full_page);
return regions;
}
// Enhanced preprocessing for mathematical formulas
void nougat_preprocess_math_regions(
std::vector<float>& image_data,
int width,
int height,
const std::vector<text_region>& math_regions) {
// Apply special preprocessing for mathematical content
for (const auto& region : math_regions) {
// Enhance contrast for mathematical symbols
for (int y = region.y; y < region.y + region.height; y++) {
for (int x = region.x; x < region.x + region.width; x++) {
for (int c = 0; c < 3; c++) {
int idx = (y * width + x) * 3 + c;
float& pixel = image_data[idx];
// Increase contrast
pixel = (pixel - 0.5f) * 1.2f + 0.5f;
pixel = std::max(0.0f, std::min(1.0f, pixel));
}
}
}
}
}
// Table detection and preprocessing
struct table_region {
text_region bounds;
int rows, cols;
std::vector<text_region> cells;
};
std::vector<table_region> nougat_detect_tables(
const std::vector<float>& image_data,
int width,
int height) {
std::vector<table_region> tables;
// Table detection would require:
// - Line detection (horizontal and vertical)
// - Grid structure analysis
// - Cell boundary detection
// Placeholder implementation
return tables;
}
// Main preprocessing pipeline for Nougat OCR
extern "C" bool nougat_preprocess_pipeline(
const char* input_path,
float** output_data,
int* output_width,
int* output_height,
int* num_pages) {
nougat_preprocess_params params;
std::vector<std::vector<float>> page_outputs;
document_metadata metadata;
// Load and preprocess document
if (!nougat_preprocess_document_file(
input_path, params, page_outputs, metadata)) {
return false;
}
// Apply document-specific processing
for (auto& page : page_outputs) {
// Detect text regions
auto text_regions = nougat_detect_text_regions(
page, params.target_width, params.target_height);
// Apply augmentations
nougat_augment_document(
page, params.target_width, params.target_height,
false, true, true);
// Detect and process mathematical regions
// (would need actual math detection)
// nougat_preprocess_math_regions(page, width, height, math_regions);
}
// Prepare output
if (!page_outputs.empty()) {
*output_width = params.target_width;
*output_height = params.target_height;
*num_pages = page_outputs.size();
// Allocate and copy data
size_t total_size = params.target_width * params.target_height * 3 * page_outputs.size();
*output_data = new float[total_size];
size_t offset = 0;
for (const auto& page : page_outputs) {
std::copy(page.begin(), page.end(), *output_data + offset);
offset += page.size();
}
return true;
}
return false;
}
// Cleanup function
extern "C" void nougat_preprocess_cleanup(float* data) {
delete[] data;
}
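For callers outside C++, a minimal sketch of driving the extern "C" entry points above from Python via ctypes. This assumes the preprocessing code is compiled into a shared library; the library name (libnougat_preprocess.so) and the input file (page.png) are placeholders, while the two symbols are exactly the functions defined above:

# Hypothetical ctypes binding for nougat_preprocess_pipeline / nougat_preprocess_cleanup.
import ctypes

lib = ctypes.CDLL("./libnougat_preprocess.so")  # assumed library name

lib.nougat_preprocess_pipeline.restype = ctypes.c_bool
lib.nougat_preprocess_pipeline.argtypes = [
    ctypes.c_char_p,                                  # input_path
    ctypes.POINTER(ctypes.POINTER(ctypes.c_float)),   # output_data
    ctypes.POINTER(ctypes.c_int),                     # output_width
    ctypes.POINTER(ctypes.c_int),                     # output_height
    ctypes.POINTER(ctypes.c_int),                     # num_pages
]
lib.nougat_preprocess_cleanup.restype = None
lib.nougat_preprocess_cleanup.argtypes = [ctypes.POINTER(ctypes.c_float)]

data = ctypes.POINTER(ctypes.c_float)()
w, h, n = ctypes.c_int(), ctypes.c_int(), ctypes.c_int()
ok = lib.nougat_preprocess_pipeline(b"page.png", ctypes.byref(data),
                                    ctypes.byref(w), ctypes.byref(h), ctypes.byref(n))
if ok:
    print(f"preprocessed {n.value} page(s) at {w.value}x{h.value}")
    lib.nougat_preprocess_cleanup(data)  # frees the buffer allocated with new[] above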


@ -0,0 +1,400 @@
#!/usr/bin/env python3
"""
Nougat Model Surgery Script
Splits the Nougat model into separate vision encoder (Swin) and text decoder (mBART) components.
Also creates the multimodal projector that connects them.
"""
import argparse
import json
import os
import sys
from pathlib import Path
from typing import Dict, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from transformers import VisionEncoderDecoderModel, NougatProcessor
# Add parent directory to import gguf
sys.path.append(str(Path(__file__).parent.parent / "gguf-py"))
import gguf
class NougatModelSurgeon:
"""Handles splitting and converting Nougat model components"""
def __init__(self, model_id: str, output_dir: str, verbose: bool = False):
self.model_id = model_id
self.output_dir = Path(output_dir)
self.output_dir.mkdir(parents=True, exist_ok=True)
self.verbose = verbose
# Load the model
print(f"Loading Nougat model from {model_id}...")
self.model = VisionEncoderDecoderModel.from_pretrained(model_id)
self.processor = NougatProcessor.from_pretrained(model_id)
def extract_vision_encoder(self) -> Dict[str, torch.Tensor]:
"""Extract Swin Transformer vision encoder weights"""
print("Extracting vision encoder (Swin Transformer)...")
vision_dict = {}
encoder = self.model.encoder
# Get all encoder parameters
for name, param in encoder.named_parameters():
# Map to our Swin naming convention
mapped_name = self._map_swin_tensor_name(name)
vision_dict[mapped_name] = param.detach().cpu()
if self.verbose:
print(f" {name} -> {mapped_name} {list(param.shape)}")
print(f" Extracted {len(vision_dict)} vision encoder tensors")
return vision_dict
def extract_text_decoder(self) -> Dict[str, torch.Tensor]:
"""Extract mBART text decoder weights"""
print("Extracting text decoder (mBART)...")
decoder_dict = {}
decoder = self.model.decoder
# Get all decoder parameters
for name, param in decoder.named_parameters():
# Map to our mBART naming convention
mapped_name = self._map_mbart_tensor_name(name)
decoder_dict[mapped_name] = param.detach().cpu()
if self.verbose:
print(f" {name} -> {mapped_name} {list(param.shape)}")
print(f" Extracted {len(decoder_dict)} text decoder tensors")
return decoder_dict
def extract_projector(self) -> Dict[str, torch.Tensor]:
"""Extract multimodal projector that connects vision and text models"""
print("Extracting multimodal projector...")
projector_dict = {}
# In Nougat, the projection happens through the decoder's cross-attention
# We need to extract the projection matrices that connect encoder outputs to decoder
# Look for cross-attention weights in decoder
for name, param in self.model.decoder.named_parameters():
if "encoder_attn" in name:
# These are the cross-attention weights that project from vision to text
mapped_name = self._map_projector_tensor_name(name)
projector_dict[mapped_name] = param.detach().cpu()
if self.verbose:
print(f" {name} -> {mapped_name} {list(param.shape)}")
# If there's a specific projection layer between encoder and decoder
if hasattr(self.model, "enc_to_dec_proj"):
projector_dict["mm.projector.weight"] = self.model.enc_to_dec_proj.weight.detach().cpu()
if hasattr(self.model.enc_to_dec_proj, "bias"):
projector_dict["mm.projector.bias"] = self.model.enc_to_dec_proj.bias.detach().cpu()
print(f" Extracted {len(projector_dict)} projector tensors")
return projector_dict
def _map_swin_tensor_name(self, name: str) -> str:
"""Map HuggingFace Swin tensor names to our convention"""
# Remove model prefix
if name.startswith("model.encoder."):
name = name[len("model.encoder."):]
elif name.startswith("encoder."):
name = name[len("encoder."):]
# Patch embeddings
if "embeddings.patch_embeddings" in name:
if "projection.weight" in name:
return "swin.patch_embed.weight"
elif "projection.bias" in name:
return "swin.patch_embed.bias"
elif "norm" in name:
return f"swin.patch_embed.norm.{'weight' if 'weight' in name else 'bias'}"
# Position embeddings
if "position_embeddings" in name:
return "swin.pos_embed"
# Parse layer structure
if "layers." in name:
parts = name.split(".")
stage_idx = None
layer_idx = None
# Find stage and layer indices
for i, part in enumerate(parts):
if part == "layers" and i + 1 < len(parts):
stage_idx = int(parts[i + 1])
if part == "blocks" and i + 1 < len(parts):
layer_idx = int(parts[i + 1])
if stage_idx is not None:
# Layer-specific components
if layer_idx is not None:
# Attention
if "attn.qkv" in name:
suffix = "weight" if "weight" in name else "bias"
return f"swin.stage.{stage_idx}.layer.{layer_idx}.attn.qkv.{suffix}"
elif "attn.proj" in name:
suffix = "weight" if "weight" in name else "bias"
return f"swin.stage.{stage_idx}.layer.{layer_idx}.attn.proj.{suffix}"
# Norms
elif "norm1" in name:
suffix = "weight" if "weight" in name else "bias"
return f"swin.stage.{stage_idx}.layer.{layer_idx}.norm1.{suffix}"
elif "norm2" in name:
suffix = "weight" if "weight" in name else "bias"
return f"swin.stage.{stage_idx}.layer.{layer_idx}.norm2.{suffix}"
# MLP/FFN
elif "mlp.fc1" in name:
suffix = "weight" if "weight" in name else "bias"
return f"swin.stage.{stage_idx}.layer.{layer_idx}.mlp.fc1.{suffix}"
elif "mlp.fc2" in name:
suffix = "weight" if "weight" in name else "bias"
return f"swin.stage.{stage_idx}.layer.{layer_idx}.mlp.fc2.{suffix}"
# Relative position bias
elif "relative_position_bias_table" in name:
return f"swin.stage.{stage_idx}.layer.{layer_idx}.attn.relative_position_bias_table"
# Downsample layers between stages
elif "downsample" in name:
if "norm" in name:
suffix = "weight" if "weight" in name else "bias"
return f"swin.stage.{stage_idx}.downsample.norm.{suffix}"
elif "reduction" in name:
return f"swin.stage.{stage_idx}.downsample.reduction.weight"
# Output normalization
if "layernorm" in name or "layer_norm" in name:
if "final" in name or "output" in name:
suffix = "weight" if "weight" in name else "bias"
return f"swin.norm.{suffix}"
# Default mapping
return f"swin.{name}"
def _map_mbart_tensor_name(self, name: str) -> str:
"""Map HuggingFace mBART tensor names to our convention"""
# Remove model prefix
if name.startswith("model.decoder."):
name = name[len("model.decoder."):]
elif name.startswith("decoder."):
name = name[len("decoder."):]
# Embeddings
if name == "embed_tokens.weight" or name == "shared.weight":
return "token_embd.weight"
elif "embed_positions" in name:
return "position_embd.weight"
# Parse decoder layers
if "layers." in name:
parts = name.split(".")
layer_idx = int(parts[1])
# Self-attention
if "self_attn.q_proj" in name:
return f"blk.{layer_idx}.attn_q.weight"
elif "self_attn.k_proj" in name:
return f"blk.{layer_idx}.attn_k.weight"
elif "self_attn.v_proj" in name:
return f"blk.{layer_idx}.attn_v.weight"
elif "self_attn.out_proj" in name:
return f"blk.{layer_idx}.attn_o.weight"
elif "self_attn_layer_norm" in name:
suffix = "weight" if "weight" in name else "bias"
return f"blk.{layer_idx}.attn_norm.{suffix}"
# Cross-attention (encoder-decoder attention)
elif "encoder_attn.q_proj" in name:
return f"blk.{layer_idx}.attn_q_cross.weight"
elif "encoder_attn.k_proj" in name:
return f"blk.{layer_idx}.attn_k_cross.weight"
elif "encoder_attn.v_proj" in name:
return f"blk.{layer_idx}.attn_v_cross.weight"
elif "encoder_attn.out_proj" in name:
return f"blk.{layer_idx}.attn_o_cross.weight"
elif "encoder_attn_layer_norm" in name:
suffix = "weight" if "weight" in name else "bias"
return f"blk.{layer_idx}.attn_norm_cross.{suffix}"
# FFN
elif "fc1" in name:
return f"blk.{layer_idx}.ffn_up.weight"
elif "fc2" in name:
return f"blk.{layer_idx}.ffn_down.weight"
elif "final_layer_norm" in name:
suffix = "weight" if "weight" in name else "bias"
return f"blk.{layer_idx}.ffn_norm.{suffix}"
# Output layers
elif "layernorm" in name or "layer_norm" in name:
suffix = "weight" if "weight" in name else "bias"
return f"output_norm.{suffix}"
elif "lm_head" in name or "output_projection" in name:
return "output.weight"
# Default mapping
return name
def _map_projector_tensor_name(self, name: str) -> str:
"""Map cross-attention tensors to projector names"""
# Extract layer index from name
if "layers." in name:
parts = name.split(".")
layer_idx = int(parts[1])
if "encoder_attn.q_proj" in name:
return f"mm.layer.{layer_idx}.q_proj.weight"
elif "encoder_attn.k_proj" in name:
return f"mm.layer.{layer_idx}.k_proj.weight"
elif "encoder_attn.v_proj" in name:
return f"mm.layer.{layer_idx}.v_proj.weight"
elif "encoder_attn.out_proj" in name:
return f"mm.layer.{layer_idx}.out_proj.weight"
return f"mm.{name}"
def save_component(self, tensors: Dict[str, torch.Tensor], filename: str, arch_name: str, description: str):
"""Save component tensors to GGUF file"""
output_path = self.output_dir / filename
print(f"Saving {arch_name} to {output_path}...")
writer = gguf.GGUFWriter(str(output_path), arch_name)
writer.add_string("general.name", arch_name)
writer.add_string("general.description", description)
writer.add_string("general.architecture", arch_name.lower())
# Add tensors
for name, tensor in tensors.items():
data = tensor.float().cpu().numpy()
writer.add_tensor(name, data)
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.write_tensors_to_file()
writer.close()
print(f" Saved {len(tensors)} tensors")
def perform_surgery(self):
"""Main surgery operation - split model into components"""
print("\n" + "=" * 60)
print("Starting Nougat Model Surgery")
print("=" * 60)
# Extract components
vision_tensors = self.extract_vision_encoder()
text_tensors = self.extract_text_decoder()
projector_tensors = self.extract_projector()
# Save components
print("\nSaving components...")
self.save_component(
vision_tensors,
"nougat-vision-swin.gguf",
"Nougat-Vision-Swin",
"Swin Transformer vision encoder from Nougat OCR model"
)
self.save_component(
text_tensors,
"nougat-text-mbart.gguf",
"Nougat-Text-mBART",
"mBART text decoder from Nougat OCR model"
)
if projector_tensors:
self.save_component(
projector_tensors,
"nougat-projector.gguf",
"Nougat-Projector",
"Multimodal projector connecting vision and text models"
)
# Save configuration
self.save_config()
print("\n" + "=" * 60)
print("Surgery Complete!")
print(f"Output files saved to: {self.output_dir}")
print("=" * 60)
def save_config(self):
"""Save model configuration for reconstruction"""
config = {
"model_id": self.model_id,
"vision_config": {
"architecture": "swin",
"image_size": 384,
"patch_size": 4,
"window_size": 7,
"num_channels": 3,
"depths": [2, 2, 6, 2],
"num_heads": [3, 6, 12, 24],
},
"text_config": {
"architecture": "mbart",
"vocab_size": self.processor.tokenizer.vocab_size,
"max_position_embeddings": 1024,
"hidden_size": self.model.config.decoder.hidden_size,
"num_layers": self.model.config.decoder.num_hidden_layers,
"num_attention_heads": self.model.config.decoder.num_attention_heads,
},
"components": {
"vision": "nougat-vision-swin.gguf",
"text": "nougat-text-mbart.gguf",
"projector": "nougat-projector.gguf" if self.extract_projector() else None,
}
}
config_path = self.output_dir / "nougat-config.json"
with open(config_path, "w") as f:
json.dump(config, f, indent=2)
print(f"\nConfiguration saved to {config_path}")
def main():
parser = argparse.ArgumentParser(description="Nougat Model Surgery - Split model into components")
parser.add_argument(
"--model-id",
type=str,
default="facebook/nougat-base",
help="HuggingFace model ID or path to local model"
)
parser.add_argument(
"--output-dir",
type=str,
default="./models/nougat-surgery",
help="Output directory for split components"
)
parser.add_argument(
"--verbose",
action="store_true",
help="Verbose output showing tensor mappings"
)
args = parser.parse_args()
surgeon = NougatModelSurgeon(args.model_id, args.output_dir, args.verbose)
surgeon.perform_surgery()
if __name__ == "__main__":
main()
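The surgery script is normally run through its argparse entry point; below is a hypothetical programmatic equivalent, assuming the file above is saved as nougat_model_surgery.py and is importable, simply reusing the argparse defaults shown above:

# Hypothetical programmatic use of the surgery class defined above.
from nougat_model_surgery import NougatModelSurgeon

surgeon = NougatModelSurgeon(
    model_id="facebook/nougat-base",       # default --model-id
    output_dir="./models/nougat-surgery",  # default --output-dir
    verbose=True,                          # print each tensor mapping
)
# Writes nougat-vision-swin.gguf, nougat-text-mbart.gguf, nougat-projector.gguf
# (when cross-attention tensors are found) and nougat-config.json into output_dir.
surgeon.perform_surgery()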

tools/mtmd/swin.cpp Normal file

@ -0,0 +1,475 @@
#include "swin.h"
#include "clip.h"
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "gguf.h"
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <stdexcept>
#include <vector>
#include <array>
// Window partition operation - splits input into non-overlapping windows
struct ggml_tensor * swin_window_partition(struct ggml_context * ctx, struct ggml_tensor * x, int window_size) {
// x shape: [batch_size, height, width, channels]
// output shape: [batch_size * num_windows, window_size, window_size, channels]
int batch_size = x->ne[3];
int H = x->ne[2];
int W = x->ne[1];
int C = x->ne[0];
int nH = H / window_size;
int nW = W / window_size;
// Reshape to [batch_size, nH, window_size, nW, window_size, C]
struct ggml_tensor * reshaped = ggml_reshape_4d(ctx, x,
C * window_size,
window_size * nW,
nH,
batch_size);
// Permute to [batch_size, nH, nW, window_size, window_size, C]
struct ggml_tensor * permuted = ggml_permute(ctx, reshaped, 0, 2, 1, 3);
// Reshape to [batch_size * nH * nW, window_size, window_size, C]
struct ggml_tensor * output = ggml_reshape_4d(ctx, permuted,
C,
window_size,
window_size,
batch_size * nH * nW);
return output;
}
// Window reverse operation - merges windows back to original spatial dimensions
struct ggml_tensor * swin_window_reverse(struct ggml_context * ctx, struct ggml_tensor * windows, int window_size, int H, int W) {
// windows shape: [batch_size * num_windows, window_size, window_size, channels]
// output shape: [batch_size, height, width, channels]
int C = windows->ne[0];
int nH = H / window_size;
int nW = W / window_size;
int batch_size = windows->ne[3] / (nH * nW);
// Reshape to [batch_size, nH, nW, window_size, window_size, C]
struct ggml_tensor * reshaped = ggml_reshape_4d(ctx, windows,
C * window_size * window_size,
nW,
nH,
batch_size);
// Permute to [batch_size, nH, window_size, nW, window_size, C]
struct ggml_tensor * permuted = ggml_permute(ctx, reshaped, 0, 2, 1, 3);
// Reshape to [batch_size, H, W, C]
struct ggml_tensor * output = ggml_reshape_4d(ctx, permuted, C, W, H, batch_size);
return output;
}
// Create attention mask for shifted window attention
struct ggml_tensor * swin_create_window_mask(struct ggml_context * ctx, int window_size, int shift_size, int H, int W) {
if (shift_size == 0) {
return nullptr; // No mask needed for non-shifted windows
}
// Create a mask tensor
struct ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, H, W);
// Initialize mask with region indices
float * mask_data = (float *)mask->data;
int h_slices[] = {0, H - window_size, H - shift_size, H};
int w_slices[] = {0, W - window_size, W - shift_size, W};
int cnt = 0;
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 3; j++) {
for (int h = h_slices[i]; h < h_slices[i + 1]; h++) {
for (int w = w_slices[j]; w < w_slices[j + 1]; w++) {
mask_data[h * W + w] = cnt;
}
}
cnt++;
}
}
return mask;
}
// Build window attention layer
static struct ggml_tensor * swin_window_attention(
struct ggml_context * ctx,
struct ggml_tensor * x,
const swin_layer & layer,
int num_heads,
int window_size,
bool shifted) {
int batch_size = x->ne[3];
int seq_len = x->ne[2] * x->ne[1]; // window_size * window_size
int hidden_dim = x->ne[0];
int head_dim = hidden_dim / num_heads;
// Reshape input for attention: [batch_size, seq_len, hidden_dim]
x = ggml_reshape_3d(ctx, x, hidden_dim, seq_len, batch_size);
// Layer norm
x = ggml_norm(ctx, x, 1e-5f); // LayerNorm epsilon (assumed default; ggml_norm expects an eps, not a dimension)
x = ggml_add(ctx, ggml_mul(ctx, x, layer.ln1_w), layer.ln1_b);
// QKV projection
struct ggml_tensor * qkv = ggml_mul_mat(ctx, layer.qkv_w, x);
qkv = ggml_add(ctx, qkv, layer.qkv_b);
// Split into Q, K, V
int qkv_dim = qkv->ne[0] / 3;
struct ggml_tensor * q = ggml_view_3d(ctx, qkv, qkv_dim, seq_len, batch_size, qkv->nb[1], qkv->nb[2], 0);
struct ggml_tensor * k = ggml_view_3d(ctx, qkv, qkv_dim, seq_len, batch_size, qkv->nb[1], qkv->nb[2], qkv_dim * ggml_element_size(qkv));
struct ggml_tensor * v = ggml_view_3d(ctx, qkv, qkv_dim, seq_len, batch_size, qkv->nb[1], qkv->nb[2], 2 * qkv_dim * ggml_element_size(qkv));
// Reshape for multi-head attention
q = ggml_reshape_4d(ctx, q, head_dim, num_heads, seq_len, batch_size);
k = ggml_reshape_4d(ctx, k, head_dim, num_heads, seq_len, batch_size);
v = ggml_reshape_4d(ctx, v, head_dim, num_heads, seq_len, batch_size);
// Transpose for attention: [batch_size, num_heads, seq_len, head_dim]
q = ggml_permute(ctx, q, 0, 2, 1, 3);
k = ggml_permute(ctx, k, 0, 2, 1, 3);
v = ggml_permute(ctx, v, 0, 2, 1, 3);
// Scaled dot-product attention
float scale = 1.0f / sqrtf(head_dim);
struct ggml_tensor * attn = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx, k)), q);
attn = ggml_scale(ctx, attn, scale);
// Add relative position bias if available
if (layer.relative_position_bias_table != nullptr) {
// This would need proper indexing based on relative positions
// For now, simplified version
attn = ggml_add(ctx, attn, layer.relative_position_bias_table);
}
// Apply mask for shifted window attention
if (shifted) {
// Create and apply attention mask
struct ggml_tensor * mask = swin_create_window_mask(ctx, window_size, window_size / 2,
window_size, window_size);
if (mask != nullptr) {
// Convert mask to attention mask
attn = ggml_add(ctx, attn, mask);
}
}
// Softmax
attn = ggml_soft_max(ctx, attn);
// Apply attention to values
struct ggml_tensor * out = ggml_mul_mat(ctx, v, attn);
// Transpose back: [batch_size, seq_len, num_heads, head_dim]
out = ggml_permute(ctx, out, 0, 2, 1, 3);
// Reshape to merge heads: [batch_size, seq_len, hidden_dim]
out = ggml_reshape_3d(ctx, out, hidden_dim, seq_len, batch_size);
// Output projection
out = ggml_mul_mat(ctx, layer.proj_w, out);
out = ggml_add(ctx, out, layer.proj_b);
return out;
}
// Build FFN layer
static struct ggml_tensor * swin_ffn(
struct ggml_context * ctx,
struct ggml_tensor * x,
const swin_layer & layer,
float mlp_ratio) {
// Layer norm
x = ggml_norm(ctx, x, 1e-5f); // LayerNorm epsilon (assumed default)
x = ggml_add(ctx, ggml_mul(ctx, x, layer.ln2_w), layer.ln2_b);
// FFN: Linear -> GELU -> Linear
x = ggml_mul_mat(ctx, layer.fc1_w, x);
x = ggml_add(ctx, x, layer.fc1_b);
x = ggml_gelu(ctx, x);
x = ggml_mul_mat(ctx, layer.fc2_w, x);
x = ggml_add(ctx, x, layer.fc2_b);
return x;
}
// Build Swin Transformer block
static struct ggml_tensor * swin_block(
struct ggml_context * ctx,
struct ggml_tensor * x,
const swin_layer & layer,
int num_heads,
int window_size,
bool shifted,
float mlp_ratio) {
int H = x->ne[2];
int W = x->ne[1];
struct ggml_tensor * shortcut = x;
// Shifted window partitioning if needed
if (shifted && (H > window_size || W > window_size)) {
// Cyclic shift
int shift_size = window_size / 2;
x = ggml_roll(ctx, x, -shift_size, 2); // Roll along H dimension
x = ggml_roll(ctx, x, -shift_size, 1); // Roll along W dimension
}
// Partition into windows
if (H > window_size || W > window_size) {
x = swin_window_partition(ctx, x, window_size);
}
// Window attention
x = swin_window_attention(ctx, x, layer, num_heads, window_size, shifted);
// Reverse window partition
if (H > window_size || W > window_size) {
x = swin_window_reverse(ctx, x, window_size, H, W);
}
// Reverse cyclic shift if needed
if (shifted && (H > window_size || W > window_size)) {
int shift_size = window_size / 2;
x = ggml_roll(ctx, x, shift_size, 2); // Roll back along H dimension
x = ggml_roll(ctx, x, shift_size, 1); // Roll back along W dimension
}
// Residual connection
x = ggml_add(ctx, x, shortcut);
// FFN with residual
shortcut = x;
x = swin_ffn(ctx, x, layer, mlp_ratio);
x = ggml_add(ctx, x, shortcut);
return x;
}
// Patch merging layer (downsampling)
static struct ggml_tensor * swin_patch_merging(
struct ggml_context * ctx,
struct ggml_tensor * x,
struct ggml_tensor * norm_w,
struct ggml_tensor * norm_b,
struct ggml_tensor * reduction) {
int batch_size = x->ne[3];
int H = x->ne[2];
int W = x->ne[1];
int C = x->ne[0];
// Reshape to merge 2x2 patches
x = ggml_reshape_4d(ctx, x, C, W/2, 2, H/2 * 2 * batch_size);
x = ggml_permute(ctx, x, 0, 2, 1, 3);
x = ggml_reshape_4d(ctx, x, C * 4, W/2, H/2, batch_size);
// Layer norm
x = ggml_norm(ctx, x, 1e-5f); // LayerNorm epsilon (assumed default)
x = ggml_add(ctx, ggml_mul(ctx, x, norm_w), norm_b);
// Linear reduction
x = ggml_mul_mat(ctx, reduction, x);
return x;
}
// Build complete Swin Transformer graph
struct ggml_cgraph * swin_build_graph(
struct swin_ctx * ctx,
const swin_image_batch * imgs,
std::pair<int, int> load_image_size,
bool is_inf) {
if (!ctx->has_vision_encoder) {
return nullptr;
}
const auto & model = ctx->vision_model;
const auto & hparams = model.hparams;
struct ggml_init_params params = {
/*.mem_size =*/ ctx->buf_compute_meta.size(),
/*.mem_buffer =*/ ctx->buf_compute_meta.data(),
/*.no_alloc =*/ true,
};
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph * cgraph = ggml_new_graph(ctx0);
const int batch_size = imgs->size;
const int image_size = hparams.image_size;
const int patch_size = hparams.patch_size;
const int num_patches_side = image_size / patch_size;
const int num_patches = num_patches_side * num_patches_side;
const int hidden_dim = hparams.hidden_dim;
// Input image tensor
struct ggml_tensor * inp = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32,
3, image_size, image_size, batch_size);
ggml_set_name(inp, "inp");
// Patch embedding: Conv2D with stride=patch_size
struct ggml_tensor * x = ggml_conv_2d(ctx0, model.patch_embed, inp, patch_size, patch_size, 0, 0, 1, 1);
// Reshape to [batch_size, num_patches, hidden_dim]
x = ggml_reshape_3d(ctx0, x, hidden_dim, num_patches, batch_size);
// Add positional embeddings if available
if (model.pos_embed != nullptr) {
x = ggml_add(ctx0, x, model.pos_embed);
}
// Layer norm after patch embedding
if (model.patch_norm_w != nullptr) {
x = ggml_norm(ctx0, x, hparams.norm_eps);
x = ggml_add(ctx0, ggml_mul(ctx0, x, model.patch_norm_w), model.patch_norm_b);
}
// Reshape for spatial processing
x = ggml_reshape_4d(ctx0, x, hidden_dim, num_patches_side, num_patches_side, batch_size);
// Process through Swin stages
int H = num_patches_side;
int W = num_patches_side;
int C = hidden_dim;
for (size_t stage_idx = 0; stage_idx < model.stages.size(); stage_idx++) {
const auto & stage = model.stages[stage_idx];
// Process layers in this stage
for (size_t layer_idx = 0; layer_idx < stage.layers.size(); layer_idx++) {
const auto & layer = stage.layers[layer_idx];
bool shifted = (layer_idx % 2 == 1); // Alternate between regular and shifted windows
x = swin_block(ctx0, x, layer,
hparams.num_heads[stage_idx],
hparams.window_size,
shifted,
hparams.mlp_ratio);
}
// Patch merging (downsampling) between stages, except for the last stage
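// Each merge halves H and W and doubles C, so after the three merges of a
// 4-stage model the grid is (H/8) x (W/8) patches with 8x the initial embedding dim.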
if (stage_idx < model.stages.size() - 1 && stage.downsample_reduction != nullptr) {
x = swin_patch_merging(ctx0, x,
stage.downsample_norm_w,
stage.downsample_norm_b,
stage.downsample_reduction);
H /= 2;
W /= 2;
C *= 2; // Channel dimension doubles after patch merging
}
}
// Global average pooling over the spatial positions
// (ggml_mean reduces along ne[0], so move H*W to the innermost axis first)
x = ggml_reshape_3d(ctx0, x, C, H * W, batch_size);
x = ggml_cont(ctx0, ggml_transpose(ctx0, x)); // [H*W, C, batch]
x = ggml_mean(ctx0, x);                       // [1, C, batch]
x = ggml_reshape_2d(ctx0, x, C, batch_size);  // [C, batch]
// Final layer norm
if (model.output_norm_w != nullptr) {
x = ggml_norm(ctx0, x, hparams.norm_eps);
x = ggml_add(ctx0, ggml_mul(ctx0, x, model.output_norm_w), model.output_norm_b);
}
ggml_set_name(x, "output");
ggml_build_forward_expand(cgraph, x);
return cgraph;
}
// Model loading function
struct swin_ctx * swin_model_load(const std::string & fname, int verbosity) {
struct swin_ctx * ctx = new swin_ctx();
struct gguf_init_params params = {
/*.no_alloc = */ true,
/*.ctx = */ &ctx->ctx,
};
struct gguf_context * gguf_ctx = gguf_init_from_file(fname.c_str(), params);
if (!gguf_ctx) {
fprintf(stderr, "%s: failed to load model from %s\n", __func__, fname.c_str());
swin_free(ctx);
return nullptr;
}
// Load hyperparameters
auto & hparams = ctx->vision_model.hparams;
// Read Swin-specific parameters from GGUF
const int n_kv = gguf_get_n_kv(gguf_ctx);
for (int i = 0; i < n_kv; ++i) {
const char * key = gguf_get_key(gguf_ctx, i);
if (strcmp(key, KEY_SWIN_WINDOW_SIZE) == 0) {
hparams.window_size = gguf_get_val_i32(gguf_ctx, i);
} else if (strcmp(key, KEY_SWIN_PATCH_SIZE) == 0) {
hparams.patch_size = gguf_get_val_i32(gguf_ctx, i);
} else if (strcmp(key, KEY_SWIN_IMAGE_SIZE) == 0) {
hparams.image_size = gguf_get_val_i32(gguf_ctx, i);
} else if (strcmp(key, KEY_SWIN_HIDDEN_DIM) == 0) {
hparams.hidden_dim = gguf_get_val_i32(gguf_ctx, i);
} else if (strcmp(key, KEY_SWIN_MLP_RATIO) == 0) {
hparams.mlp_ratio = gguf_get_val_f32(gguf_ctx, i);
} else if (strcmp(key, KEY_SWIN_NORM_EPS) == 0) {
hparams.norm_eps = gguf_get_val_f32(gguf_ctx, i);
}
// TODO: Load depths and num_heads arrays
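// A possible sketch, assuming the converter stores them as GGUF INT32 arrays:
//   } else if (strcmp(key, KEY_SWIN_DEPTHS) == 0) {
//       const size_t n = gguf_get_arr_n(gguf_ctx, i);
//       const int32_t * v = (const int32_t *) gguf_get_arr_data(gguf_ctx, i);
//       hparams.depths.assign(v, v + n);
//   }   // and likewise for KEY_SWIN_NUM_HEADS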
}
ctx->has_vision_encoder = true;
if (verbosity >= 1) {
printf("Swin Transformer model loaded:\n");
printf(" image_size: %d\n", hparams.image_size);
printf(" patch_size: %d\n", hparams.patch_size);
printf(" window_size: %d\n", hparams.window_size);
printf(" hidden_dim: %d\n", hparams.hidden_dim);
printf(" num_stages: %d\n", hparams.num_stages());
}
// TODO: Load actual tensor weights from GGUF file
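// Typical pattern (sketch): iterate gguf_get_n_tensors(gguf_ctx), look each tensor up by
// name in ctx->ctx, allocate ctx->params_buffer on the backend, and read the tensor data
// from the file at gguf_get_data_offset() + gguf_get_tensor_offset(i), as clip.cpp does
// for its vision encoder.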
gguf_free(gguf_ctx);
return ctx;
}
// Free context
void swin_free(struct swin_ctx * ctx) {
if (ctx == nullptr) {
return;
}
if (ctx->backend) {
ggml_backend_free(ctx->backend);
}
if (ctx->params_buffer) {
ggml_backend_buffer_free(ctx->params_buffer);
}
if (ctx->compute_buffer) {
ggml_backend_buffer_free(ctx->compute_buffer);
}
if (ctx->ctx) {
ggml_free(ctx->ctx);
}
delete ctx;
}

153
tools/mtmd/swin.h Normal file
View File

@ -0,0 +1,153 @@
#pragma once
#include "ggml.h"
#include "ggml-backend.h"
#include "clip-impl.h"
#include <vector>
#include <string>
// Swin Transformer constants
#define KEY_SWIN_WINDOW_SIZE "swin.window_size"
#define KEY_SWIN_PATCH_SIZE "swin.patch_size"
#define KEY_SWIN_IMAGE_SIZE "swin.image_size"
#define KEY_SWIN_DEPTHS "swin.depths"
#define KEY_SWIN_NUM_HEADS "swin.num_heads"
#define KEY_SWIN_HIDDEN_DIM "swin.hidden_dim"
#define KEY_SWIN_NUM_CHANNELS "swin.num_channels"
#define KEY_SWIN_MLP_RATIO "swin.mlp_ratio"
#define KEY_SWIN_DROP_PATH_RATE "swin.drop_path_rate"
#define KEY_SWIN_NORM_EPS "swin.norm_eps"
// Tensor names for Swin Transformer
#define TN_SWIN_PATCH_EMBED "swin.patch_embed.weight"
#define TN_SWIN_PATCH_NORM "swin.patch_embed.norm.%s"
#define TN_SWIN_POS_EMBED "swin.pos_embed"
#define TN_SWIN_DOWNSAMPLE_NORM "swin.stage.%d.downsample.norm.%s"
#define TN_SWIN_DOWNSAMPLE_PROJ "swin.stage.%d.downsample.reduction.weight"
#define TN_SWIN_ATTN_NORM "swin.stage.%d.layer.%d.norm1.%s"
#define TN_SWIN_ATTN_QKV "swin.stage.%d.layer.%d.attn.qkv.%s"
#define TN_SWIN_ATTN_PROJ "swin.stage.%d.layer.%d.attn.proj.%s"
#define TN_SWIN_ATTN_REL_POS "swin.stage.%d.layer.%d.attn.relative_position_bias_table"
#define TN_SWIN_FFN_NORM "swin.stage.%d.layer.%d.norm2.%s"
#define TN_SWIN_FFN_FC1 "swin.stage.%d.layer.%d.mlp.fc1.%s"
#define TN_SWIN_FFN_FC2 "swin.stage.%d.layer.%d.mlp.fc2.%s"
#define TN_SWIN_OUTPUT_NORM "swin.norm.%s"
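// In the format strings above, %d placeholders take the stage / layer index and
// %s takes "weight" or "bias" when the names are rendered at load time.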
// Forward declarations
struct swin_ctx;
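// Image containers used by the Swin API below. These definitions are a sketch inferred
// from how tools/nougat-cli.cpp initializes them ({width, height, channels, data} and
// {size, data}); adjust them if the implementation defines these types elsewhere.
struct swin_image_u8 {
int nx;
int ny;
int nc;
std::vector<uint8_t> data;
};
struct swin_image_f32 {
int nx;
int ny;
int nc;
std::vector<float> data;
};
struct swin_image_f32_batch {
size_t size;
swin_image_f32 * data;
};
// Mirrors swin_image_f32_batch because nougat-cli.cpp fills it with f32 images.
struct swin_image_batch {
size_t size;
swin_image_f32 * data;
};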
// Swin Transformer hyperparameters
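// Defaults below follow a Swin-Tiny-style configuration; the values actually used are
// read from GGUF metadata in swin_model_load() (Nougat's encoder uses a larger config).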
struct swin_hparams {
int32_t image_size = 384;
int32_t patch_size = 4;
int32_t num_channels = 3;
int32_t window_size = 7;
int32_t hidden_dim = 96;
std::vector<int32_t> depths = {2, 2, 6, 2}; // depths for each stage
std::vector<int32_t> num_heads = {3, 6, 12, 24}; // number of heads for each stage
float mlp_ratio = 4.0f;
float drop_path_rate = 0.1f;
float norm_eps = 1e-5f;
bool use_checkpoint = false;
// Computed values
int32_t num_stages() const { return depths.size(); }
int32_t num_patches() const { return (image_size / patch_size) * (image_size / patch_size); }
};
// Swin Transformer layer
struct swin_layer {
// Window attention
struct ggml_tensor * ln1_w;
struct ggml_tensor * ln1_b;
struct ggml_tensor * qkv_w;
struct ggml_tensor * qkv_b;
struct ggml_tensor * proj_w;
struct ggml_tensor * proj_b;
struct ggml_tensor * relative_position_bias_table;
// FFN
struct ggml_tensor * ln2_w;
struct ggml_tensor * ln2_b;
struct ggml_tensor * fc1_w;
struct ggml_tensor * fc1_b;
struct ggml_tensor * fc2_w;
struct ggml_tensor * fc2_b;
};
// Swin Transformer stage
struct swin_stage {
std::vector<swin_layer> layers;
// Patch merging (downsample) layer
struct ggml_tensor * downsample_norm_w = nullptr;
struct ggml_tensor * downsample_norm_b = nullptr;
struct ggml_tensor * downsample_reduction = nullptr;
};
// Swin Transformer vision model
struct swin_vision_model {
swin_hparams hparams;
// Patch embedding
struct ggml_tensor * patch_embed;
struct ggml_tensor * patch_norm_w;
struct ggml_tensor * patch_norm_b;
struct ggml_tensor * pos_embed;
// Stages
std::vector<swin_stage> stages;
// Output norm
struct ggml_tensor * output_norm_w;
struct ggml_tensor * output_norm_b;
};
// Main Swin context
struct swin_ctx {
bool has_vision_encoder = false;
bool has_projector = false;
swin_vision_model vision_model;
// Backend and compute
struct ggml_backend * backend = nullptr;
ggml_backend_buffer_t params_buffer = nullptr;
struct ggml_context * ctx = nullptr;
std::vector<uint8_t> buf_compute_meta;
// GGML compute resources
struct ggml_backend_buffer * compute_buffer = nullptr;
struct ggml_context * ctx_compute = nullptr;
struct ggml_alloc * compute_alloc = nullptr;
};
// Public API functions
struct swin_ctx * swin_model_load(const std::string & fname, int verbosity = 1);
void swin_free(struct swin_ctx * ctx);
// Build Swin Transformer graph for inference
struct ggml_cgraph * swin_build_graph(
struct swin_ctx * ctx,
const swin_image_batch * imgs,
std::pair<int, int> load_image_size = {0, 0},
bool is_inf = false);
// Encode image batch
bool swin_image_batch_encode(
struct swin_ctx * ctx,
int n_threads,
const swin_image_batch * imgs,
float * vec);
// Utility functions
int swin_patch_size(const struct swin_ctx * ctx);
bool swin_image_preprocess(struct swin_ctx * ctx, const swin_image_u8 * img, swin_image_f32 * res);
bool swin_image_batch_preprocess(struct swin_ctx * ctx, int n_threads, const swin_image_batch * imgs, swin_image_f32_batch * res_batch);
// Window operations for Swin Transformer
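// swin_window_partition splits the [C, W, H, B] feature map into non-overlapping
// window_size x window_size tiles so attention runs per window; swin_window_reverse
// undoes the split. The mask helpers build the attention mask that blocks
// cross-window attention in the shifted (SW-MSA) blocks.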
struct ggml_tensor * swin_window_partition(struct ggml_context * ctx, struct ggml_tensor * x, int window_size);
struct ggml_tensor * swin_window_reverse(struct ggml_context * ctx, struct ggml_tensor * windows, int window_size, int H, int W);
struct ggml_tensor * swin_create_window_mask(struct ggml_context * ctx, int window_size, int shift_size, int H, int W);
struct ggml_tensor * swin_compute_mask(struct ggml_context * ctx, int window_size, int shift_size, int H, int W);

539
tools/nougat-cli.cpp Normal file
View File

@ -0,0 +1,539 @@
#include "common.h"
#include "ggml.h"
#include "llama.h"
#include "mtmd/swin.h"
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include <thread>
#include <chrono>
// External preprocessing function
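// Implemented separately; it is expected to rasterize the input document and return all
// pages as one packed float RGB buffer of num_pages * height * width * 3 values, which is
// how the per-page pointer arithmetic in main() consumes it.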
extern "C" bool nougat_preprocess_pipeline(
const char* input_path,
float** output_data,
int* output_width,
int* output_height,
int* num_pages);
extern "C" void nougat_preprocess_cleanup(float* data);
// CLI arguments structure
struct nougat_params {
std::string input_path = "";
std::string output_path = "";
std::string vision_model = "models/nougat-vision-swin.gguf";
std::string text_model = "models/nougat-text-mbart.gguf";
std::string projector_model = "models/nougat-projector.gguf";
// Processing options
bool batch_mode = false;
int batch_size = 1;
int n_threads = 4;
int n_gpu_layers = 0;
// Output options
std::string output_format = "markdown"; // markdown, latex, plain
bool verbose = false;
bool save_intermediate = false;
// Performance options
bool use_mmap = true;
bool use_flash_attn = false;
int context_size = 2048;
// Document-specific options
bool deskew = true;
bool denoise = true;
bool detect_tables = true;
bool detect_math = true;
int max_pages = -1; // -1 for all pages
};
static void print_usage(const char* prog_name) {
fprintf(stdout, "\n");
fprintf(stdout, "Nougat OCR - Neural Optical Understanding for Academic Documents\n");
fprintf(stdout, "\n");
fprintf(stdout, "Usage: %s [options] -i input_file -o output_file\n", prog_name);
fprintf(stdout, "\n");
fprintf(stdout, "Options:\n");
fprintf(stdout, " -i, --input FILE Input document (PDF, PNG, JPG)\n");
fprintf(stdout, " -o, --output FILE Output file path\n");
fprintf(stdout, " --vision-model FILE Path to vision model GGUF (default: models/nougat-vision-swin.gguf)\n");
fprintf(stdout, " --text-model FILE Path to text model GGUF (default: models/nougat-text-mbart.gguf)\n");
fprintf(stdout, " --projector FILE Path to projector model GGUF (default: models/nougat-projector.gguf)\n");
fprintf(stdout, "\n");
fprintf(stdout, " Processing Options:\n");
fprintf(stdout, " -t, --threads N Number of threads (default: 4)\n");
fprintf(stdout, " -ngl, --n-gpu-layers N Number of layers to offload to GPU (default: 0)\n");
fprintf(stdout, " -b, --batch-size N Batch size for processing (default: 1)\n");
fprintf(stdout, " -c, --context-size N Context size (default: 2048)\n");
fprintf(stdout, " --max-pages N Maximum pages to process (default: all)\n");
fprintf(stdout, "\n");
fprintf(stdout, " Output Options:\n");
fprintf(stdout, " -f, --format FORMAT Output format: markdown, latex, plain (default: markdown)\n");
fprintf(stdout, " -v, --verbose Verbose output\n");
fprintf(stdout, " --save-intermediate Save intermediate processing results\n");
fprintf(stdout, "\n");
fprintf(stdout, " Document Processing:\n");
fprintf(stdout, " --no-deskew Disable automatic deskewing\n");
fprintf(stdout, " --no-denoise Disable denoising\n");
fprintf(stdout, " --no-tables Disable table detection\n");
fprintf(stdout, " --no-math Disable math formula detection\n");
fprintf(stdout, "\n");
fprintf(stdout, " Performance Options:\n");
fprintf(stdout, " --no-mmap Disable memory mapping\n");
fprintf(stdout, " --flash-attn Use flash attention\n");
fprintf(stdout, "\n");
fprintf(stdout, "Examples:\n");
fprintf(stdout, " # Basic OCR of a PDF document\n");
fprintf(stdout, " %s -i paper.pdf -o paper.md\n", prog_name);
fprintf(stdout, "\n");
fprintf(stdout, " # Process with GPU acceleration\n");
fprintf(stdout, " %s -i scan.png -o text.md -ngl 32 -t 8\n", prog_name);
fprintf(stdout, "\n");
fprintf(stdout, " # LaTeX output with math detection\n");
fprintf(stdout, " %s -i math_paper.pdf -o paper.tex -f latex\n", prog_name);
fprintf(stdout, "\n");
}
static bool parse_args(int argc, char** argv, nougat_params& params) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-i" || arg == "--input") {
if (++i >= argc) {
fprintf(stderr, "Error: Missing argument for %s\n", arg.c_str());
return false;
}
params.input_path = argv[i];
}
else if (arg == "-o" || arg == "--output") {
if (++i >= argc) {
fprintf(stderr, "Error: Missing argument for %s\n", arg.c_str());
return false;
}
params.output_path = argv[i];
}
else if (arg == "--vision-model") {
if (++i >= argc) {
fprintf(stderr, "Error: Missing argument for %s\n", arg.c_str());
return false;
}
params.vision_model = argv[i];
}
else if (arg == "--text-model") {
if (++i >= argc) {
fprintf(stderr, "Error: Missing argument for %s\n", arg.c_str());
return false;
}
params.text_model = argv[i];
}
else if (arg == "--projector") {
if (++i >= argc) {
fprintf(stderr, "Error: Missing argument for %s\n", arg.c_str());
return false;
}
params.projector_model = argv[i];
}
else if (arg == "-t" || arg == "--threads") {
if (++i >= argc) {
fprintf(stderr, "Error: Missing argument for %s\n", arg.c_str());
return false;
}
params.n_threads = std::stoi(argv[i]);
}
else if (arg == "-ngl" || arg == "--n-gpu-layers") {
if (++i >= argc) {
fprintf(stderr, "Error: Missing argument for %s\n", arg.c_str());
return false;
}
params.n_gpu_layers = std::stoi(argv[i]);
}
else if (arg == "-b" || arg == "--batch-size") {
if (++i >= argc) {
fprintf(stderr, "Error: Missing argument for %s\n", arg.c_str());
return false;
}
params.batch_size = std::stoi(argv[i]);
}
else if (arg == "-c" || arg == "--context-size") {
if (++i >= argc) {
fprintf(stderr, "Error: Missing argument for %s\n", arg.c_str());
return false;
}
params.context_size = std::stoi(argv[i]);
}
else if (arg == "--max-pages") {
if (++i >= argc) {
fprintf(stderr, "Error: Missing argument for %s\n", arg.c_str());
return false;
}
params.max_pages = std::stoi(argv[i]);
}
else if (arg == "-f" || arg == "--format") {
if (++i >= argc) {
fprintf(stderr, "Error: Missing argument for %s\n", arg.c_str());
return false;
}
params.output_format = argv[i];
}
else if (arg == "-v" || arg == "--verbose") {
params.verbose = true;
}
else if (arg == "--save-intermediate") {
params.save_intermediate = true;
}
else if (arg == "--no-deskew") {
params.deskew = false;
}
else if (arg == "--no-denoise") {
params.denoise = false;
}
else if (arg == "--no-tables") {
params.detect_tables = false;
}
else if (arg == "--no-math") {
params.detect_math = false;
}
else if (arg == "--no-mmap") {
params.use_mmap = false;
}
else if (arg == "--flash-attn") {
params.use_flash_attn = true;
}
else if (arg == "-h" || arg == "--help") {
print_usage(argv[0]);
exit(0);
}
else {
fprintf(stderr, "Error: Unknown argument '%s'\n", arg.c_str());
return false;
}
}
// Validate required arguments
if (params.input_path.empty()) {
fprintf(stderr, "Error: Input file is required\n");
return false;
}
if (params.output_path.empty()) {
// Generate default output path
size_t dot_pos = params.input_path.find_last_of(".");
params.output_path = params.input_path.substr(0, dot_pos);
if (params.output_format == "markdown") {
params.output_path += ".md";
} else if (params.output_format == "latex") {
params.output_path += ".tex";
} else {
params.output_path += ".txt";
}
}
return true;
}
// Process a single page through the Nougat pipeline
static std::string process_page(
struct swin_ctx* vision_ctx,
struct llama_model* text_model,
struct llama_context* text_ctx,
const float* image_data,
int width,
int height,
const nougat_params& params) {
// Step 1: Encode image with Swin Transformer
if (params.verbose) {
printf("Encoding image with Swin Transformer...\n");
}
// Create image batch
swin_image_f32 img = {
width,
height,
3,
std::vector<float>(image_data, image_data + width * height * 3)
};
swin_image_batch imgs = {1, &img};
// Encode image
std::vector<float> vision_embeddings(2048); // Adjust size based on model
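// NOTE: 2048 is a placeholder; the real size is (number of encoder output positions) x
// (embedding dim) and should be derived from the loaded Swin hyperparameters.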
if (!swin_image_batch_encode(vision_ctx, params.n_threads, &imgs, vision_embeddings.data())) {
fprintf(stderr, "Failed to encode image\n");
return "";
}
// Step 2: Pass embeddings through projector
// This would map vision embeddings to text embedding space
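// Not implemented yet: without the projector and the cross-attention wiring below,
// the decoder generates unconditioned text, so this path is currently a scaffold
// rather than a working OCR pipeline.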
// Step 3: Generate text with mBART decoder
if (params.verbose) {
printf("Generating text with mBART decoder...\n");
}
// Create batch for text generation
llama_batch batch = llama_batch_init(params.context_size, 0, 1);
// Set up cross-attention with vision embeddings
// This requires the decoder to attend to encoder outputs
// Start with BOS token
llama_token bos_token = llama_token_bos(text_model);
batch.token[0] = bos_token;
batch.pos[0] = 0;
batch.n_seq_id[0] = 1;
batch.seq_id[0][0] = 0;
batch.logits[0] = true; // request logits at this position so we can sample from it
batch.n_tokens = 1;
// Decode initial token
if (llama_decode(text_ctx, batch) != 0) {
fprintf(stderr, "Failed to decode\n");
llama_batch_free(batch);
return "";
}
// Generate text autoregressively
std::vector<llama_token> generated_tokens;
generated_tokens.push_back(bos_token);
llama_token eos_token = llama_token_eos(text_model);
int max_tokens = params.context_size;
for (int i = 1; i < max_tokens; i++) {
// Get logits from last position
float* logits = llama_get_logits_ith(text_ctx, batch.n_tokens - 1);
// Sample next token
int n_vocab = llama_n_vocab(text_model);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
// Sample with top-k and top-p
int top_k = 40;
float top_p = 0.9f;
float temp = 0.8f;
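// These sampling settings are arbitrary defaults; the reference Nougat decoder is
// (near-)greedy, so a lower temperature or plain argmax may be more faithful for OCR.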
llama_sample_top_k(text_ctx, &candidates_p, top_k, 1);
llama_sample_top_p(text_ctx, &candidates_p, top_p, 1);
llama_sample_temp(text_ctx, &candidates_p, temp);
llama_token new_token = llama_sample_token(text_ctx, &candidates_p);
// Check for EOS
if (new_token == eos_token) {
break;
}
generated_tokens.push_back(new_token);
// Add to batch for next iteration
batch.token[0] = new_token;
batch.pos[0] = i;
batch.logits[0] = true;
batch.n_tokens = 1;
if (llama_decode(text_ctx, batch) != 0) {
fprintf(stderr, "Failed to continue decoding\n");
break;
}
}
llama_batch_free(batch);
// Convert tokens to text
std::string result;
for (auto token : generated_tokens) {
std::string piece = llama_token_to_piece(text_ctx, token, true);
result += piece;
}
return result;
}
int main(int argc, char** argv) {
nougat_params params;
// Parse command line arguments
if (!parse_args(argc, argv, params)) {
print_usage(argv[0]);
return 1;
}
// Print banner
printf("\n");
printf("╔═══════════════════════════════════════════════════════╗\n");
printf("║ Nougat OCR - Document Understanding ║\n");
printf("║ Powered by Swin Transformer + mBART ║\n");
printf("╚═══════════════════════════════════════════════════════╝\n");
printf("\n");
printf("Input: %s\n", params.input_path.c_str());
printf("Output: %s\n", params.output_path.c_str());
printf("Format: %s\n", params.output_format.c_str());
printf("\n");
// Initialize backend
llama_backend_init();
// Load vision model (Swin Transformer)
printf("Loading vision model from %s...\n", params.vision_model.c_str());
struct swin_ctx* vision_ctx = swin_model_load(params.vision_model, params.verbose ? 2 : 1);
if (!vision_ctx) {
fprintf(stderr, "Failed to load vision model\n");
return 1;
}
// Load text model (mBART)
printf("Loading text model from %s...\n", params.text_model.c_str());
llama_model_params model_params = llama_model_default_params();
model_params.n_gpu_layers = params.n_gpu_layers;
model_params.use_mmap = params.use_mmap;
struct llama_model* text_model = llama_load_model_from_file(
params.text_model.c_str(), model_params);
if (!text_model) {
fprintf(stderr, "Failed to load text model\n");
swin_free(vision_ctx);
return 1;
}
// Create text generation context
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = params.context_size;
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads;
ctx_params.flash_attn = params.use_flash_attn;
struct llama_context* text_ctx = llama_new_context_with_model(text_model, ctx_params);
if (!text_ctx) {
fprintf(stderr, "Failed to create text context\n");
llama_free_model(text_model);
swin_free(vision_ctx);
return 1;
}
// Preprocess document
printf("Preprocessing document...\n");
float* preprocessed_data = nullptr;
int width, height, num_pages;
if (!nougat_preprocess_pipeline(
params.input_path.c_str(),
&preprocessed_data,
&width, &height, &num_pages)) {
fprintf(stderr, "Failed to preprocess document\n");
llama_free(text_ctx);
llama_free_model(text_model);
swin_free(vision_ctx);
return 1;
}
printf("Document info: %d pages, %dx%d pixels\n", num_pages, width, height);
// Limit pages if requested
if (params.max_pages > 0 && num_pages > params.max_pages) {
num_pages = params.max_pages;
printf("Processing first %d pages only\n", num_pages);
}
// Process each page
std::string full_output;
auto start_time = std::chrono::high_resolution_clock::now();
for (int page = 0; page < num_pages; page++) {
printf("\nProcessing page %d/%d...\n", page + 1, num_pages);
float* page_data = preprocessed_data + (size_t) page * width * height * 3;
std::string page_text = process_page(
vision_ctx, text_model, text_ctx,
page_data, width, height, params);
if (page_text.empty()) {
fprintf(stderr, "Warning: Failed to process page %d\n", page + 1);
continue;
}
// Add page separator for multi-page documents
if (page > 0) {
if (params.output_format == "markdown") {
full_output += "\n\n---\n\n";
} else if (params.output_format == "latex") {
full_output += "\n\\newpage\n\n";
} else {
full_output += "\n\n[Page " + std::to_string(page + 1) + "]\n\n";
}
}
full_output += page_text;
// Save intermediate results if requested
if (params.save_intermediate) {
std::string intermediate_file = params.output_path + ".page" +
std::to_string(page + 1) + ".tmp";
std::ofstream tmp_out(intermediate_file);
tmp_out << page_text;
tmp_out.close();
}
}
auto end_time = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::seconds>(end_time - start_time);
// Save final output
printf("\nSaving output to %s...\n", params.output_path.c_str());
std::ofstream output_file(params.output_path);
if (!output_file) {
fprintf(stderr, "Failed to open output file\n");
} else {
// Add format-specific headers/footers
if (params.output_format == "latex") {
output_file << "\\documentclass{article}\n";
output_file << "\\usepackage{amsmath}\n";
output_file << "\\usepackage{graphicx}\n";
output_file << "\\begin{document}\n\n";
}
output_file << full_output;
if (params.output_format == "latex") {
output_file << "\n\n\\end{document}\n";
}
output_file.close();
}
// Print statistics
printf("\n");
printf("╔════════════════════════════════════╗\n");
printf("║ OCR Complete! ║\n");
printf("╠════════════════════════════════════╣\n");
printf("║ Pages processed: %-17d ║\n", num_pages);
printf("║ Time taken: %-17lds║\n", duration.count());
printf("║ Output size: %-17zd ║\n", full_output.size());
printf("╚════════════════════════════════════╝\n");
// Cleanup
nougat_preprocess_cleanup(preprocessed_data);
llama_free(text_ctx);
llama_free_model(text_model);
swin_free(vision_ctx);
llama_backend_free();
return 0;
}

View File

@ -0,0 +1,27 @@
set(TARGET nougat-cli)
# Add executable
add_executable(${TARGET} ../nougat-cli.cpp)
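# NOTE: the Swin encoder sources (e.g. ../mtmd/swin.cpp) must also be compiled into this
# target or provided by a linked library, otherwise the swin_* symbols used by
# nougat-cli.cpp will not resolve.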
# Link with llama library
target_link_libraries(${TARGET} PRIVATE
llama
${CMAKE_THREAD_LIBS_INIT}
)
# Include directories
target_include_directories(${TARGET} PRIVATE
${CMAKE_SOURCE_DIR}/include
${CMAKE_SOURCE_DIR}/tools
)
# Compile flags
llama_add_compile_flags()
# Set output name
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME "nougat-cli")
# Install target
if(LLAMA_INSTALL)
install(TARGETS ${TARGET} RUNTIME DESTINATION bin)
endif()