Merge d3aea508a1 into 2aa45ef9e3
commit 755aeef41c
@ -0,0 +1,386 @@
#!/usr/bin/env python3
"""
Convert Nougat (Neural Optical Understanding for Academic Documents) model to GGUF format.

This script handles the conversion of Nougat's Swin Transformer encoder and mBART decoder.
"""

import argparse
import sys
from pathlib import Path
from typing import Dict

import torch
from transformers import NougatProcessor, VisionEncoderDecoderModel

# Add parent directory to path to import gguf
sys.path.append(str(Path(__file__).parent / "gguf-py"))
import gguf

# Constants for Nougat
NOUGAT_VISION_PREFIX = "vision_model"
NOUGAT_DECODER_PREFIX = "decoder"
NOUGAT_ENCODER_PREFIX = "encoder"


def parse_args():
    parser = argparse.ArgumentParser(description="Convert Nougat model to GGUF format")
    parser.add_argument(
        "--model-id",
        type=str,
        default="facebook/nougat-base",
        help="HuggingFace model ID or path to local model",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./models",
        help="Output directory for GGUF files",
    )
    parser.add_argument(
        "--quantization",
        type=str,
        choices=["f32", "f16", "q8_0", "q4_0", "q4_1"],
        default="f16",
        help="Quantization type for model weights",
    )
    parser.add_argument(
        "--split-model",
        action="store_true",
        help="Split into separate vision and text GGUF files",
    )
    parser.add_argument(
        "--vocab-only",
        action="store_true",
        help="Only export vocabulary/tokenizer",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Verbose output during conversion",
    )
    return parser.parse_args()


def get_tensor_name(name: str) -> str:
    """Map Nougat tensor names to GGUF tensor names."""

    # Vision model (Swin Transformer) mappings
    if name.startswith("encoder.model.encoder."):
        # Swin encoder layers
        name = name.replace("encoder.model.encoder.", "swin.")

        # Patch embedding
        if "embeddings.patch_embeddings" in name:
            if "projection.weight" in name:
                return "swin.patch_embed.weight"
            elif "projection.bias" in name:
                return "swin.patch_embed.bias"

        # Position embeddings
        if "position_embeddings" in name:
            return "swin.pos_embed"

        # Layer mappings
        if "layers." in name:
            # Extract stage and layer indices
            parts = name.split(".")
            suffix = "weight" if "weight" in name else "bias"
            for i, part in enumerate(parts):
                if part == "layers":
                    stage_idx = int(parts[i + 1])
                    if "blocks." in name:
                        block_idx = int(parts[parts.index("blocks") + 1])

                        # Attention components
                        if "attn.qkv" in name:
                            return f"swin.stage.{stage_idx}.layer.{block_idx}.attn.qkv.{suffix}"
                        elif "attn.proj" in name:
                            return f"swin.stage.{stage_idx}.layer.{block_idx}.attn.proj.{suffix}"
                        elif "norm1" in name:
                            return f"swin.stage.{stage_idx}.layer.{block_idx}.norm1.{suffix}"
                        elif "norm2" in name:
                            return f"swin.stage.{stage_idx}.layer.{block_idx}.norm2.{suffix}"
                        elif "mlp.fc1" in name:
                            return f"swin.stage.{stage_idx}.layer.{block_idx}.mlp.fc1.{suffix}"
                        elif "mlp.fc2" in name:
                            return f"swin.stage.{stage_idx}.layer.{block_idx}.mlp.fc2.{suffix}"

                    # Downsample layers between stages
                    elif "downsample" in name:
                        if "norm" in name:
                            return f"swin.stage.{stage_idx}.downsample.norm.{suffix}"
                        elif "reduction" in name:
                            return f"swin.stage.{stage_idx}.downsample.reduction.weight"

    # Decoder model (mBART) mappings
    elif name.startswith("decoder.model."):
        name = name.replace("decoder.model.", "")

        # Token and position embeddings
        if name == "shared.weight":
            return "token_embd.weight"
        elif name == "decoder.embed_positions.weight":
            return "position_embd.weight"

        # Decoder layers
        if "decoder.layers." in name:
            layer_idx = int(name.split(".")[2])
            suffix = "weight" if "weight" in name else "bias"

            # Self-attention
            if "self_attn.q_proj" in name:
                return f"blk.{layer_idx}.attn_q.weight"
            elif "self_attn.k_proj" in name:
                return f"blk.{layer_idx}.attn_k.weight"
            elif "self_attn.v_proj" in name:
                return f"blk.{layer_idx}.attn_v.weight"
            elif "self_attn.out_proj" in name:
                return f"blk.{layer_idx}.attn_o.weight"
            elif "self_attn_layer_norm" in name:
                return f"blk.{layer_idx}.attn_norm.{suffix}"

            # Cross-attention
            elif "encoder_attn.q_proj" in name:
                return f"blk.{layer_idx}.attn_q_cross.weight"
            elif "encoder_attn.k_proj" in name:
                return f"blk.{layer_idx}.attn_k_cross.weight"
            elif "encoder_attn.v_proj" in name:
                return f"blk.{layer_idx}.attn_v_cross.weight"
            elif "encoder_attn.out_proj" in name:
                return f"blk.{layer_idx}.attn_o_cross.weight"
            elif "encoder_attn_layer_norm" in name:
                return f"blk.{layer_idx}.attn_norm_cross.{suffix}"

            # FFN
            elif "fc1" in name:
                return f"blk.{layer_idx}.ffn_up.weight"
            elif "fc2" in name:
                return f"blk.{layer_idx}.ffn_down.weight"
            elif "final_layer_norm" in name:
                return f"blk.{layer_idx}.ffn_norm.{suffix}"

        # Output layers
        elif "decoder.layer_norm" in name:
            suffix = "weight" if "weight" in name else "bias"
            return f"output_norm.{suffix}"
        elif "lm_head" in name:
            return "output.weight"

    # Encoder layers (for encoder-only export); the same component mappings
    # apply, but under the enc. prefix (replace once to avoid a double prefix)
    elif name.startswith("encoder."):
        return name.replace("encoder.", "enc.", 1)

    # Default: return the original name unchanged
    return name
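
# Illustrative mappings produced by get_tensor_name() above (checkpoint names
# follow the HuggingFace facebook/nougat-base layout; two examples, not a full list):
#
#   encoder.model.encoder.embeddings.patch_embeddings.projection.weight
#       -> swin.patch_embed.weight
#   decoder.model.decoder.layers.3.self_attn.q_proj.weight
#       -> blk.3.attn_q.weight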


def convert_swin_encoder(model_dict: Dict[str, torch.Tensor], gguf_writer: gguf.GGUFWriter, args):
    """Convert Swin Transformer encoder weights to GGUF format."""

    print("Converting Swin Transformer encoder...")

    # Swin hyperparameters (defaults for facebook/nougat-base)
    swin_config = {
        "window_size": 7,
        "patch_size": 4,
        "image_size": 384,  # default for Nougat
        "hidden_dim": 96,
        "depths": [2, 2, 6, 2],
        "num_heads": [3, 6, 12, 24],
        "mlp_ratio": 4.0,
        "norm_eps": 1e-5,
    }

    gguf_writer.add_string("swin.type", "swin_transformer")
    gguf_writer.add_int32("swin.window_size", swin_config["window_size"])
    gguf_writer.add_int32("swin.patch_size", swin_config["patch_size"])
    gguf_writer.add_int32("swin.image_size", swin_config["image_size"])
    gguf_writer.add_int32("swin.hidden_dim", swin_config["hidden_dim"])
    gguf_writer.add_float32("swin.mlp_ratio", swin_config["mlp_ratio"])
    gguf_writer.add_float32("swin.norm_eps", swin_config["norm_eps"])
    # per-stage arrays
    gguf_writer.add_array("swin.depths", swin_config["depths"])
    gguf_writer.add_array("swin.num_heads", swin_config["num_heads"])

    # Convert encoder weights
    encoder_tensors = {k: v for k, v in model_dict.items() if k.startswith("encoder.")}

    for name, tensor in encoder_tensors.items():
        gguf_name = get_tensor_name(name)

        if args.verbose:
            print(f"  {name} -> {gguf_name} {list(tensor.shape)}")

        # Convert to the requested dtype
        if args.quantization == "f32":
            data = tensor.float().cpu().numpy()
        elif args.quantization == "f16":
            data = tensor.half().cpu().numpy()
        else:
            # Integer quantization (q8_0/q4_0/q4_1) would be applied here;
            # fall back to f32 for now
            data = tensor.float().cpu().numpy()

        gguf_writer.add_tensor(gguf_name, data)

    print(f"  Converted {len(encoder_tensors)} encoder tensors")


def convert_mbart_decoder(model_dict: Dict[str, torch.Tensor], gguf_writer: gguf.GGUFWriter, args):
    """Convert mBART decoder weights to GGUF format."""

    print("Converting mBART decoder...")

    # Convert decoder weights; the general.architecture metadata is written by
    # the caller so that it is not set twice in a combined file
    decoder_tensors = {k: v for k, v in model_dict.items() if k.startswith("decoder.")}

    for name, tensor in decoder_tensors.items():
        gguf_name = get_tensor_name(name)

        if args.verbose:
            print(f"  {name} -> {gguf_name} {list(tensor.shape)}")

        # Convert to the requested dtype
        if args.quantization == "f32":
            data = tensor.float().cpu().numpy()
        elif args.quantization == "f16":
            data = tensor.half().cpu().numpy()
        else:
            # Integer quantization (q8_0/q4_0/q4_1) would be applied here;
            # fall back to f32 for now
            data = tensor.float().cpu().numpy()

        gguf_writer.add_tensor(gguf_name, data)

    print(f"  Converted {len(decoder_tensors)} decoder tensors")


def convert_tokenizer(processor, gguf_writer: gguf.GGUFWriter, args):
    """Convert the Nougat tokenizer/processor to GGUF format."""

    print("Converting tokenizer...")

    tokenizer = processor.tokenizer
    vocab = tokenizer.get_vocab()

    # Write tokenizer metadata
    gguf_writer.add_string("tokenizer.model", "mbart")
    gguf_writer.add_int32("tokenizer.vocab_size", len(vocab))

    # Add special tokens
    special_tokens = {
        "bos": tokenizer.bos_token,
        "eos": tokenizer.eos_token,
        "unk": tokenizer.unk_token,
        "pad": tokenizer.pad_token,
    }

    for key, token in special_tokens.items():
        if token:
            gguf_writer.add_string(f"tokenizer.{key}_token", token)
            gguf_writer.add_int32(f"tokenizer.{key}_token_id", tokenizer.convert_tokens_to_ids(token))

    # Add the vocabulary, sorted by token id
    tokens = []
    scores = []
    token_types = []

    for token, token_id in sorted(vocab.items(), key=lambda x: x[1]):
        tokens.append(token.encode("utf-8"))
        scores.append(0.0)  # dummy scores for now
        token_types.append(1 if token in tokenizer.all_special_tokens else 0)

    gguf_writer.add_token_list(tokens)
    gguf_writer.add_token_scores(scores)
    gguf_writer.add_token_types(token_types)

    print(f"  Vocabulary size: {len(vocab)}")


def main():
    args = parse_args()

    # Create the output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Loading Nougat model from {args.model_id}...")

    # Load the model and processor
    processor = NougatProcessor.from_pretrained(args.model_id)
    model = VisionEncoderDecoderModel.from_pretrained(args.model_id)

    # Get the model state dict
    state_dict = model.state_dict()

    if args.split_model:
        # Create separate files for the vision and text models

        # Vision model (Swin encoder)
        vision_output = output_dir / "nougat-vision.gguf"
        print(f"\nCreating vision model: {vision_output}")

        vision_writer = gguf.GGUFWriter(str(vision_output), "nougat-vision")
        vision_writer.add_string("general.name", "Nougat Vision Model (Swin)")
        vision_writer.add_string("general.description", "Swin Transformer encoder for Nougat OCR")
        vision_writer.add_string("general.architecture", "swin")

        convert_swin_encoder(state_dict, vision_writer, args)
        vision_writer.write_header_to_file()
        vision_writer.write_kv_data_to_file()
        vision_writer.write_tensors_to_file()
        vision_writer.close()

        # Text model (mBART decoder)
        text_output = output_dir / "nougat-text.gguf"
        print(f"\nCreating text model: {text_output}")

        text_writer = gguf.GGUFWriter(str(text_output), "nougat-text")
        text_writer.add_string("general.name", "Nougat Text Model (mBART)")
        text_writer.add_string("general.description", "mBART decoder for Nougat OCR")
        text_writer.add_string("general.architecture", "mbart")

        convert_mbart_decoder(state_dict, text_writer, args)
        convert_tokenizer(processor, text_writer, args)

        text_writer.write_header_to_file()
        text_writer.write_kv_data_to_file()
        text_writer.write_tensors_to_file()
        text_writer.close()

    else:
        # Create a single combined model file
        output_file = output_dir / "nougat-combined.gguf"
        print(f"\nCreating combined model: {output_file}")

        writer = gguf.GGUFWriter(str(output_file), "nougat")
        writer.add_string("general.name", "Nougat OCR Model")
        writer.add_string("general.description", "Neural Optical Understanding for Academic Documents")
        writer.add_string("general.architecture", "nougat")

        # Add both the encoder and the decoder
        convert_swin_encoder(state_dict, writer, args)
        convert_mbart_decoder(state_dict, writer, args)

        if not args.vocab_only:
            convert_tokenizer(processor, writer, args)

        writer.write_header_to_file()
        writer.write_kv_data_to_file()
        writer.write_tensors_to_file()
        writer.close()

    print("\nConversion complete!")

    # Print model statistics; the prefix match avoids counting the decoder's
    # "encoder_attn" cross-attention weights as encoder parameters
    total_params = sum(p.numel() for p in model.parameters())
    print("\nModel statistics:")
    print(f"  Total parameters: {total_params:,}")
    print(f"  Encoder parameters: {sum(p.numel() for n, p in model.named_parameters() if n.startswith('encoder.')):,}")
    print(f"  Decoder parameters: {sum(p.numel() for n, p in model.named_parameters() if n.startswith('decoder.')):,}")

    if args.quantization != "f32":
        print(f"  Quantization: {args.quantization}")


if __name__ == "__main__":
    main()
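
# Example invocation (assuming this script is saved as convert-nougat-to-gguf.py
# next to the gguf-py directory; the filename is illustrative):
#
#   python convert-nougat-to-gguf.py --model-id facebook/nougat-base \
#       --output-dir ./models --quantization f16 --split-model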
@ -410,6 +410,8 @@ class MODEL_ARCH(IntEnum):
    BITNET       = auto()
    T5           = auto()
    T5ENCODER    = auto()
    MBART        = auto()
    MBARTENCODER = auto()
    JAIS         = auto()
    NEMOTRON     = auto()
    NEMOTRON_H   = auto()
@ -784,6 +786,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
    MODEL_ARCH.BITNET:       "bitnet",
    MODEL_ARCH.T5:           "t5",
    MODEL_ARCH.T5ENCODER:    "t5encoder",
    MODEL_ARCH.MBART:        "mbart",
    MODEL_ARCH.MBARTENCODER: "mbartencoder",
    MODEL_ARCH.JAIS:         "jais",
    MODEL_ARCH.NEMOTRON:     "nemotron",
    MODEL_ARCH.NEMOTRON_H:   "nemotron_h",
@ -2485,6 +2489,48 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
        MODEL_TENSOR.ENC_FFN_UP,
        MODEL_TENSOR.ENC_OUTPUT_NORM,
    ],
    MODEL_ARCH.MBART: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.POS_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
        MODEL_TENSOR.OUTPUT,
        MODEL_TENSOR.ATTN_NORM,
        MODEL_TENSOR.ATTN_Q,
        MODEL_TENSOR.ATTN_K,
        MODEL_TENSOR.ATTN_V,
        MODEL_TENSOR.ATTN_OUT,
        MODEL_TENSOR.ATTN_NORM_CROSS,
        MODEL_TENSOR.ATTN_Q_CROSS,
        MODEL_TENSOR.ATTN_K_CROSS,
        MODEL_TENSOR.ATTN_V_CROSS,
        MODEL_TENSOR.ATTN_OUT_CROSS,
        MODEL_TENSOR.FFN_NORM,
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
        MODEL_TENSOR.ENC_ATTN_NORM,
        MODEL_TENSOR.ENC_ATTN_Q,
        MODEL_TENSOR.ENC_ATTN_K,
        MODEL_TENSOR.ENC_ATTN_V,
        MODEL_TENSOR.ENC_ATTN_OUT,
        MODEL_TENSOR.ENC_FFN_NORM,
        MODEL_TENSOR.ENC_FFN_DOWN,
        MODEL_TENSOR.ENC_FFN_UP,
        MODEL_TENSOR.ENC_OUTPUT_NORM,
    ],
    MODEL_ARCH.MBARTENCODER: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.POS_EMBD,
        MODEL_TENSOR.OUTPUT,
        MODEL_TENSOR.ENC_ATTN_NORM,
        MODEL_TENSOR.ENC_ATTN_Q,
        MODEL_TENSOR.ENC_ATTN_K,
        MODEL_TENSOR.ENC_ATTN_V,
        MODEL_TENSOR.ENC_ATTN_OUT,
        MODEL_TENSOR.ENC_FFN_NORM,
        MODEL_TENSOR.ENC_FFN_DOWN,
        MODEL_TENSOR.ENC_FFN_UP,
        MODEL_TENSOR.ENC_OUTPUT_NORM,
    ],
    MODEL_ARCH.JAIS: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
@ -72,6 +72,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    { LLM_ARCH_BITNET,       "bitnet"       },
    { LLM_ARCH_T5,           "t5"           },
    { LLM_ARCH_T5ENCODER,    "t5encoder"    },
    { LLM_ARCH_MBART,        "mbart"        },
    { LLM_ARCH_MBARTENCODER, "mbartencoder" },
    { LLM_ARCH_JAIS,         "jais"         },
    { LLM_ARCH_NEMOTRON,     "nemotron"     },
    { LLM_ARCH_NEMOTRON_H,   "nemotron_h"   },
@ -1706,6 +1708,54 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
            { LLM_TENSOR_ENC_FFN_UP,      "enc.blk.%d.ffn_up" },
        },
    },
    {
        LLM_ARCH_MBART,
        {
            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
            { LLM_TENSOR_POS_EMBD,        "position_embd" },
            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
            { LLM_TENSOR_OUTPUT,          "output" },
            { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" },
            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_o" },
            { LLM_TENSOR_ATTN_NORM_CROSS, "blk.%d.attn_norm_cross" },
            { LLM_TENSOR_ATTN_Q_CROSS,    "blk.%d.attn_q_cross" },
            { LLM_TENSOR_ATTN_K_CROSS,    "blk.%d.attn_k_cross" },
            { LLM_TENSOR_ATTN_V_CROSS,    "blk.%d.attn_v_cross" },
            { LLM_TENSOR_ATTN_OUT_CROSS,  "blk.%d.attn_o_cross" },
            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
            { LLM_TENSOR_ENC_ATTN_NORM,   "enc.blk.%d.attn_norm" },
            { LLM_TENSOR_ENC_ATTN_Q,      "enc.blk.%d.attn_q" },
            { LLM_TENSOR_ENC_ATTN_K,      "enc.blk.%d.attn_k" },
            { LLM_TENSOR_ENC_ATTN_V,      "enc.blk.%d.attn_v" },
            { LLM_TENSOR_ENC_ATTN_OUT,    "enc.blk.%d.attn_o" },
            { LLM_TENSOR_ENC_FFN_NORM,    "enc.blk.%d.ffn_norm" },
            { LLM_TENSOR_ENC_FFN_DOWN,    "enc.blk.%d.ffn_down" },
            { LLM_TENSOR_ENC_FFN_UP,      "enc.blk.%d.ffn_up" },
        },
    },
    {
        LLM_ARCH_MBARTENCODER,
        {
            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
            { LLM_TENSOR_POS_EMBD,        "position_embd" },
            { LLM_TENSOR_OUTPUT,          "output" },
            { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" },
            { LLM_TENSOR_ENC_ATTN_NORM,   "enc.blk.%d.attn_norm" },
            { LLM_TENSOR_ENC_ATTN_Q,      "enc.blk.%d.attn_q" },
            { LLM_TENSOR_ENC_ATTN_K,      "enc.blk.%d.attn_k" },
            { LLM_TENSOR_ENC_ATTN_V,      "enc.blk.%d.attn_v" },
            { LLM_TENSOR_ENC_ATTN_OUT,    "enc.blk.%d.attn_o" },
            { LLM_TENSOR_ENC_FFN_NORM,    "enc.blk.%d.ffn_norm" },
            { LLM_TENSOR_ENC_FFN_DOWN,    "enc.blk.%d.ffn_down" },
            { LLM_TENSOR_ENC_FFN_UP,      "enc.blk.%d.ffn_up" },
        },
    },
    {
        LLM_ARCH_JAIS,
        {
@ -76,6 +76,8 @@ enum llm_arch {
    LLM_ARCH_BITNET,
    LLM_ARCH_T5,
    LLM_ARCH_T5ENCODER,
    LLM_ARCH_MBART,
    LLM_ARCH_MBARTENCODER,
    LLM_ARCH_JAIS,
    LLM_ARCH_NEMOTRON,
    LLM_ARCH_NEMOTRON_H,
@ -326,6 +328,11 @@ enum llm_tensor {
    LLM_TENSOR_ATTN_OUT,
    LLM_TENSOR_ATTN_NORM,
    LLM_TENSOR_ATTN_NORM_2,
    LLM_TENSOR_ATTN_NORM_CROSS,
    LLM_TENSOR_ATTN_Q_CROSS,
    LLM_TENSOR_ATTN_K_CROSS,
    LLM_TENSOR_ATTN_V_CROSS,
    LLM_TENSOR_ATTN_OUT_CROSS,
    LLM_TENSOR_ATTN_OUT_NORM,
    LLM_TENSOR_ATTN_POST_NORM,
    LLM_TENSOR_ATTN_ROT_EMBD,
@ -0,0 +1,162 @@
#include "models.h"

llm_build_mbart_dec::llm_build_mbart_dec(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
    const int64_t n_embd_head = hparams.n_embd_head_v;

    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);

    ggml_tensor * cur;
    ggml_tensor * inpL;

    // mBART uses learned positional embeddings
    inpL = build_inp_embd(model.tok_embd);

    // add positional embeddings
    ggml_tensor * pos_embd = build_inp_pos_embd();
    if (pos_embd) {
        inpL = ggml_add(ctx0, inpL, pos_embd);
        cb(inpL, "pos_embd", -1);
    }

    // get encoder embeddings for cross-attention
    ggml_tensor * embd_enc = build_inp_cross_embd();
    const int64_t n_outputs_enc = embd_enc->ne[1];

    // layer normalization before the first layer (mBART characteristic)
    cur = build_norm(inpL,
            model.output_norm, NULL,
            LLM_NORM, -1);
    cb(cur, "input_norm", -1);
    inpL = cur;

    auto * inp_attn_self  = build_attn_inp_kv();
    auto * inp_attn_cross = build_attn_inp_cross();

    ggml_tensor * inp_out_ids = build_inp_out_ids();

    const int64_t dec_n_layer = hparams.dec_n_layer;

    for (int il = 0; il < dec_n_layer; ++il) {
        ggml_tensor * inpSA = inpL;

        // self-attention
        {
            // norm before attention (pre-norm)
            cur = build_norm(inpL,
                    model.layers[il].attn_norm, NULL,
                    LLM_NORM, il);
            cb(cur, "attn_norm", il);

            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
            cb(Qcur, "Qcur", il);

            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
            cb(Kcur, "Kcur", il);

            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
            cb(Vcur, "Vcur", il);

            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

            // mBART uses standard scaled dot-product attention
            cur = build_attn(inp_attn_self,
                    model.layers[il].wo, model.layers[il].bo,
                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf((float)n_embd_head), il);
            cb(cur, "kqv_out", il);
        }

        // residual connection
        cur = ggml_add(ctx0, cur, inpSA);
        cb(cur, "self_attn_out", il);

        ggml_tensor * inpCA = cur;

        // cross-attention
        {
            // norm before cross-attention
            cur = build_norm(cur,
                    model.layers[il].attn_norm_cross, NULL,
                    LLM_NORM, il);
            cb(cur, "attn_norm_cross", il);

            // Q from the decoder, K and V from the encoder output
            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq_cross, cur);
            cb(Qcur, "Qcur_cross", il);

            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk_cross, embd_enc);
            cb(Kcur, "Kcur_cross", il);

            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv_cross, embd_enc);
            cb(Vcur, "Vcur_cross", il);

            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc);
            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_outputs_enc);

            cur = build_attn(inp_attn_cross,
                    model.layers[il].wo_cross, model.layers[il].bo_cross,
                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf((float)n_embd_head), il);
            cb(cur, "kqv_cross_out", il);
        }

        // on the last layer, keep only the rows that outputs are needed for
        if (il == dec_n_layer - 1 && inp_out_ids) {
            cur   = ggml_get_rows(ctx0, cur,   inp_out_ids);
            inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids);
        }

        // residual connection
        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpCA);
        cb(ffn_inp, "cross_attn_out", il);

        // feed-forward network
        {
            // norm before FFN
            cur = build_norm(ffn_inp,
                    model.layers[il].ffn_norm, NULL,
                    LLM_NORM, il);
            cb(cur, "ffn_norm", il);

            // mBART uses GELU activation
            cur = build_ffn(cur,
                    model.layers[il].ffn_up,   NULL, NULL,
                    NULL,                      NULL, NULL,
                    model.layers[il].ffn_down, NULL, NULL,
                    NULL,
                    LLM_FFN_GELU,
                    LLM_FFN_SEQ,
                    il);
            cb(cur, "ffn_out", il);
        }

        // residual connection
        cur = ggml_add(ctx0, cur, ffn_inp);
        cb(cur, "layer_out", il);

        cur = build_cvec(cur, il);
        cb(cur, "l_out", il);

        // input for the next layer
        inpL = cur;
    }

    cur = inpL;
    cb(cur, "result_embd", -1);

    // final layer normalization
    cur = build_norm(cur,
            model.output_norm, NULL,
            LLM_NORM, -1);

    cb(cur, "result_norm", -1);
    res->t_embd = cur;

    // lm_head for generation
    cur = build_lora_mm(model.output, cur);

    cb(cur, "result_output", -1);
    res->t_logits = cur;

    ggml_build_forward_expand(gf, cur);
}
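
// Per-layer dataflow built above (pre-norm residual blocks):
//   x = x + SelfAttn(LN(x))                // causal self-attention over decoder tokens
//   x = x + CrossAttn(LN(x), enc_output)   // Q from decoder, K/V from encoder embeddings
//   x = x + FFN(LN(x))                     // ffn_up -> GELU -> ffn_down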
@ -0,0 +1,114 @@
#include "models.h"

llm_build_mbart_enc::llm_build_mbart_enc(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
    const int64_t n_embd_head = hparams.n_embd_head_v;

    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);

    ggml_tensor * cur;
    ggml_tensor * inpL;

    // mBART uses learned positional embeddings
    inpL = build_inp_embd(model.tok_embd);

    // add positional embeddings for mBART
    ggml_tensor * pos_embd = build_inp_pos_embd();
    if (pos_embd) {
        inpL = ggml_add(ctx0, inpL, pos_embd);
        cb(inpL, "pos_embd", -1);
    }

    // layer normalization before the first layer (mBART characteristic)
    cur = build_norm(inpL,
            model.output_norm_enc, NULL,
            LLM_NORM, -1);
    cb(cur, "input_norm", -1);
    inpL = cur;

    auto * inp_attn = build_attn_inp_no_cache();
    ggml_tensor * inp_out_ids = build_inp_out_ids();

    for (int il = 0; il < n_layer; ++il) {
        ggml_tensor * inpSA = inpL;

        // self-attention (mBART uses pre-norm)
        {
            // norm before attention
            cur = build_norm(inpL,
                    model.layers[il].attn_norm_enc, NULL,
                    LLM_NORM, il);
            cb(cur, "attn_norm", il);

            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq_enc, cur);
            cb(Qcur, "Qcur", il);

            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk_enc, cur);
            cb(Kcur, "Kcur", il);

            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv_enc, cur);
            cb(Vcur, "Vcur", il);

            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

            // mBART uses standard scaled dot-product attention without relative position bias
            cur = build_attn(inp_attn,
                    model.layers[il].wo_enc, model.layers[il].bo_enc,
                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf((float)n_embd_head), il);
            cb(cur, "kqv_out", il);
        }

        if (il == n_layer - 1 && inp_out_ids) {
            cur   = ggml_get_rows(ctx0, cur,   inp_out_ids);
            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
        }

        // residual connection
        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
        cb(ffn_inp, "attn_out", il);

        // feed-forward network
        {
            // norm before FFN
            cur = build_norm(ffn_inp,
                    model.layers[il].ffn_norm_enc, NULL,
                    LLM_NORM, il);
            cb(cur, "ffn_norm", il);

            // mBART uses GELU activation
            cur = build_ffn(cur,
                    model.layers[il].ffn_up_enc,   NULL, NULL,
                    NULL,                          NULL, NULL,
                    model.layers[il].ffn_down_enc, NULL, NULL,
                    NULL,
                    LLM_FFN_GELU,
                    LLM_FFN_SEQ,
                    il);
            cb(cur, "ffn_out", il);
        }

        // residual connection
        cur = ggml_add(ctx0, cur, ffn_inp);
        cb(cur, "ffn_out", il);

        cur = build_cvec(cur, il);
        cb(cur, "l_out", il);

        // input for next layer
        inpL = cur;
    }

    cur = inpL;
    cb(cur, "result_embd", -1);

    // final layer normalization
    cur = build_norm(cur,
            model.output_norm_enc, NULL,
            LLM_NORM, -1);

    cb(cur, "result_norm", -1);
    res->t_embd = cur;

    ggml_build_forward_expand(gf, cur);
}
@ -535,6 +535,14 @@ struct llm_build_t5_enc : public llm_graph_context {
    llm_build_t5_enc(const llama_model & model, const llm_graph_params & params);
};

struct llm_build_mbart_enc : public llm_graph_context {
    llm_build_mbart_enc(const llama_model & model, const llm_graph_params & params);
};

struct llm_build_mbart_dec : public llm_graph_context {
    llm_build_mbart_dec(const llama_model & model, const llm_graph_params & params);
};

struct llm_build_wavtokenizer_dec : public llm_graph_context {
    llm_build_wavtokenizer_dec(const llama_model & model, const llm_graph_params & params);
};
@ -29,6 +29,7 @@ else()
    add_subdirectory(tokenize)
    add_subdirectory(tts)
    add_subdirectory(mtmd)
    add_subdirectory(nougat)
    if (GGML_RPC)
        add_subdirectory(rpc)
    endif()
@ -0,0 +1,386 @@
#include "clip.h"
#include "swin.h"

#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <vector>
#include <string>
#include <memory>
#include <algorithm>

// External image loading library integration (stb_image)
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"

#ifdef _WIN32
#include <windows.h>
#else
#include <sys/stat.h>
#endif

// Document-specific preprocessing parameters for Nougat
struct nougat_preprocess_params {
    int target_width  = 896;  // Nougat uses a different resolution than standard vision models
    int target_height = 1344; // optimized for document aspect ratios
    float mean[3] = {0.485f, 0.456f, 0.406f}; // ImageNet normalization
    float std[3]  = {0.229f, 0.224f, 0.225f};
    bool center_crop     = false; // documents should not be center-cropped
    bool maintain_aspect = true;  // important for documents
    int patch_size = 4;           // Swin Transformer patch size
};

// Structure to hold document metadata
struct document_metadata {
    int original_width;
    int original_height;
    int num_pages;
    std::string format; // PDF, PNG, JPG, etc.
    float dpi;
};

// Preprocess a single document image for Nougat
static bool preprocess_document_image(
        const uint8_t * img_data,
        int width,
        int height,
        int channels,
        const nougat_preprocess_params & params,
        std::vector<float> & output) {

    // Calculate scaling to fit the target dimensions while maintaining aspect ratio
    float scale_w = static_cast<float>(params.target_width)  / width;
    float scale_h = static_cast<float>(params.target_height) / height;
    float scale   = params.maintain_aspect ? std::min(scale_w, scale_h) : 1.0f;

    int new_width  = static_cast<int>(width  * scale);
    int new_height = static_cast<int>(height * scale);

    // Ensure dimensions are divisible by the patch size
    new_width  = (new_width  / params.patch_size) * params.patch_size;
    new_height = (new_height / params.patch_size) * params.patch_size;

    // Resize the image using bilinear interpolation
    std::vector<uint8_t> resized_img(new_width * new_height * 3);

    for (int y = 0; y < new_height; y++) {
        for (int x = 0; x < new_width; x++) {
            float src_x = x / scale;
            float src_y = y / scale;

            int x0 = static_cast<int>(src_x);
            int y0 = static_cast<int>(src_y);
            int x1 = std::min(x0 + 1, width  - 1);
            int y1 = std::min(y0 + 1, height - 1);

            float fx = src_x - x0;
            float fy = src_y - y0;

            for (int c = 0; c < 3; c++) {
                float v00 = img_data[(y0 * width + x0) * channels + c];
                float v10 = img_data[(y0 * width + x1) * channels + c];
                float v01 = img_data[(y1 * width + x0) * channels + c];
                float v11 = img_data[(y1 * width + x1) * channels + c];

                float v0 = v00 * (1 - fx) + v10 * fx;
                float v1 = v01 * (1 - fx) + v11 * fx;
                float v  = v0  * (1 - fy) + v1  * fy;

                resized_img[(y * new_width + x) * 3 + c] = static_cast<uint8_t>(v);
            }
        }
    }

    // Pad to the target size if needed
    int pad_left = (params.target_width  - new_width)  / 2;
    int pad_top  = (params.target_height - new_height) / 2;

    output.resize(params.target_width * params.target_height * 3);

    // Initialize the padding with a white background (appropriate for documents);
    // the fill value must be in *normalized* space, i.e. (1.0 - mean) / std per channel
    for (int i = 0; i < params.target_width * params.target_height; i++) {
        for (int c = 0; c < 3; c++) {
            output[i * 3 + c] = (1.0f - params.mean[c]) / params.std[c];
        }
    }

    // Copy the resized image into the output with normalization
    for (int y = 0; y < new_height; y++) {
        for (int x = 0; x < new_width; x++) {
            int out_x = x + pad_left;
            int out_y = y + pad_top;

            if (out_x >= 0 && out_x < params.target_width &&
                out_y >= 0 && out_y < params.target_height) {

                for (int c = 0; c < 3; c++) {
                    float pixel = resized_img[(y * new_width + x) * 3 + c] / 255.0f;
                    pixel = (pixel - params.mean[c]) / params.std[c];
                    output[(out_y * params.target_width + out_x) * 3 + c] = pixel;
                }
            }
        }
    }

    return true;
}
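
// Worked example of the sizing logic above, for an A4 page scanned at 300 DPI
// (2480 x 3508 px): scale = min(896/2480, 1344/3508) ~ 0.3613, so the page is
// resized to 896 x 1267, rounded down to the patch multiple 896 x 1264, and
// padded with 40 rows of white at the top and bottom to reach 896 x 1344.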

// Load and preprocess a document file (supports various formats)
bool nougat_preprocess_document_file(
        const std::string & filename,
        nougat_preprocess_params & params,
        std::vector<std::vector<float>> & page_outputs,
        document_metadata & metadata) {

    // Check the file extension
    std::string ext = filename.substr(filename.find_last_of(".") + 1);
    std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);

    metadata.format = ext;

    if (ext == "pdf") {
        // PDF processing would require a PDF library such as poppler or mupdf;
        // for now, PDF files are rejected
        fprintf(stderr, "PDF processing not yet implemented. Please convert to an image format.\n");
        return false;
    }

    // Load the image using stb_image (force 3 channels)
    int width, height, channels;
    unsigned char * img_data = stbi_load(filename.c_str(), &width, &height, &channels, 3);

    if (!img_data) {
        fprintf(stderr, "Failed to load image: %s\n", filename.c_str());
        return false;
    }

    metadata.original_width  = width;
    metadata.original_height = height;
    metadata.num_pages       = 1;      // single image
    metadata.dpi             = 300.0f; // assume standard document DPI

    // Preprocess the image
    std::vector<float> output;
    bool success = preprocess_document_image(
        img_data, width, height, 3, params, output);

    if (success) {
        page_outputs.push_back(output);
    }

    stbi_image_free(img_data);
    return success;
}

// Batch preprocessing for multiple document pages
bool nougat_preprocess_document_batch(
        const std::vector<std::string> & filenames,
        nougat_preprocess_params & params,
        std::vector<std::vector<float>> & outputs) {

    outputs.clear();
    outputs.reserve(filenames.size());

    for (const auto & filename : filenames) {
        document_metadata metadata;
        std::vector<std::vector<float>> page_outputs;

        if (!nougat_preprocess_document_file(filename, params, page_outputs, metadata)) {
            fprintf(stderr, "Failed to preprocess: %s\n", filename.c_str());
            continue;
        }

        // Add all pages from this document
        for (auto & page : page_outputs) {
            outputs.push_back(std::move(page));
        }
    }

    return !outputs.empty();
}

// Apply document-specific augmentations
void nougat_augment_document(
        std::vector<float> & image_data,
        int width,
        int height,
        bool random_rotation = false,
        bool deskew  = true,
        bool denoise = true) {

    // Document deskewing (straighten tilted scans)
    if (deskew) {
        // Simplified deskew - a real implementation would estimate the skew
        // angle using a Hough transform or a similar technique
    }

    // Denoising for scanned documents
    if (denoise) {
        // Apply a 3x3 median filter (simplified implementation)
        std::vector<float> temp = image_data;

        for (int y = 1; y < height - 1; y++) {
            for (int x = 1; x < width - 1; x++) {
                for (int c = 0; c < 3; c++) {
                    std::vector<float> neighborhood;

                    // Collect the 3x3 neighborhood
                    for (int dy = -1; dy <= 1; dy++) {
                        for (int dx = -1; dx <= 1; dx++) {
                            int idx = ((y + dy) * width + (x + dx)) * 3 + c;
                            neighborhood.push_back(temp[idx]);
                        }
                    }

                    // Median of the 9 samples
                    std::sort(neighborhood.begin(), neighborhood.end());
                    image_data[(y * width + x) * 3 + c] = neighborhood[4];
                }
            }
        }
    }

    // Random rotation for augmentation during training
    if (random_rotation) {
        // Apply a small random rotation (-5 to +5 degrees);
        // would need a proper rotation implementation
    }
}

// Extract text regions from the document for focused processing
struct text_region {
    int x, y, width, height;
    float confidence;
};

std::vector<text_region> nougat_detect_text_regions(
        const std::vector<float> & image_data,
        int width,
        int height) {

    std::vector<text_region> regions;

    // Proper text detection would use:
    // - edge detection
    // - connected component analysis
    // - text/non-text classification

    // For now, return the whole image as a single region
    text_region full_page;
    full_page.x = 0;
    full_page.y = 0;
    full_page.width  = width;
    full_page.height = height;
    full_page.confidence = 1.0f;

    regions.push_back(full_page);

    return regions;
}

// Enhanced preprocessing for mathematical formulas
void nougat_preprocess_math_regions(
        std::vector<float> & image_data,
        int width,
        int height,
        const std::vector<text_region> & math_regions) {

    // Apply special preprocessing for mathematical content
    for (const auto & region : math_regions) {
        // Enhance contrast for mathematical symbols
        for (int y = region.y; y < region.y + region.height; y++) {
            for (int x = region.x; x < region.x + region.width; x++) {
                for (int c = 0; c < 3; c++) {
                    int idx = (y * width + x) * 3 + c;
                    float & pixel = image_data[idx];

                    // Increase contrast around the midpoint
                    pixel = (pixel - 0.5f) * 1.2f + 0.5f;
                    pixel = std::max(0.0f, std::min(1.0f, pixel));
                }
            }
        }
    }
}

// Table detection and preprocessing
struct table_region {
    text_region bounds;
    int rows, cols;
    std::vector<text_region> cells;
};

std::vector<table_region> nougat_detect_tables(
        const std::vector<float> & image_data,
        int width,
        int height) {

    std::vector<table_region> tables;

    // Table detection would require:
    // - line detection (horizontal and vertical)
    // - grid structure analysis
    // - cell boundary detection

    // Placeholder implementation
    return tables;
}

// Main preprocessing pipeline for Nougat OCR
extern "C" bool nougat_preprocess_pipeline(
        const char * input_path,
        float ** output_data,
        int * output_width,
        int * output_height,
        int * num_pages) {

    nougat_preprocess_params params;
    std::vector<std::vector<float>> page_outputs;
    document_metadata metadata;

    // Load and preprocess the document
    if (!nougat_preprocess_document_file(
            input_path, params, page_outputs, metadata)) {
        return false;
    }

    // Apply document-specific processing
    for (auto & page : page_outputs) {
        // Detect text regions (currently informational only)
        auto text_regions = nougat_detect_text_regions(
            page, params.target_width, params.target_height);
        (void) text_regions;

        // Apply augmentations (no rotation; deskew and denoise enabled)
        nougat_augment_document(
            page, params.target_width, params.target_height,
            false, true, true);

        // Detect and process mathematical regions
        // (would need actual math detection)
        // nougat_preprocess_math_regions(page, width, height, math_regions);
    }

    // Prepare the output
    if (!page_outputs.empty()) {
        *output_width  = params.target_width;
        *output_height = params.target_height;
        *num_pages     = (int) page_outputs.size();

        // Allocate and copy the data
        size_t total_size = (size_t) params.target_width * params.target_height * 3 * page_outputs.size();
        *output_data = new float[total_size];

        size_t offset = 0;
        for (const auto & page : page_outputs) {
            std::copy(page.begin(), page.end(), *output_data + offset);
            offset += page.size();
        }

        return true;
    }

    return false;
}

// Cleanup function for buffers returned by nougat_preprocess_pipeline
extern "C" void nougat_preprocess_cleanup(float * data) {
    delete[] data;
}
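
// Illustrative usage of the extern "C" API above (error handling elided;
// "page.png" is a placeholder path):
//
//   float * data = nullptr;
//   int w = 0, h = 0, pages = 0;
//   if (nougat_preprocess_pipeline("page.png", &data, &w, &h, &pages)) {
//       // data holds `pages` planes of w * h * 3 normalized floats
//       nougat_preprocess_cleanup(data);
//   }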
@ -0,0 +1,400 @@
#!/usr/bin/env python3
"""
Nougat Model Surgery Script

Splits the Nougat model into separate vision encoder (Swin) and text decoder (mBART)
components, and creates the multimodal projector that connects them.
"""

import argparse
import json
import sys
from pathlib import Path
from typing import Dict

import torch
from transformers import VisionEncoderDecoderModel, NougatProcessor

# Add parent directory to import gguf
sys.path.append(str(Path(__file__).parent.parent / "gguf-py"))
import gguf


class NougatModelSurgeon:
    """Handles splitting and converting Nougat model components."""

    def __init__(self, model_id: str, output_dir: str, verbose: bool = False):
        self.model_id = model_id
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.verbose = verbose
        self.projector_tensors: Dict[str, torch.Tensor] = {}

        # Load the model
        print(f"Loading Nougat model from {model_id}...")
        self.model = VisionEncoderDecoderModel.from_pretrained(model_id)
        self.processor = NougatProcessor.from_pretrained(model_id)

    def extract_vision_encoder(self) -> Dict[str, torch.Tensor]:
        """Extract Swin Transformer vision encoder weights."""
        print("Extracting vision encoder (Swin Transformer)...")

        vision_dict = {}
        encoder = self.model.encoder

        # Get all encoder parameters
        for name, param in encoder.named_parameters():
            # Map to our Swin naming convention
            mapped_name = self._map_swin_tensor_name(name)
            vision_dict[mapped_name] = param.detach().cpu()

            if self.verbose:
                print(f"  {name} -> {mapped_name} {list(param.shape)}")

        print(f"  Extracted {len(vision_dict)} vision encoder tensors")
        return vision_dict

    def extract_text_decoder(self) -> Dict[str, torch.Tensor]:
        """Extract mBART text decoder weights."""
        print("Extracting text decoder (mBART)...")

        decoder_dict = {}
        decoder = self.model.decoder

        # Get all decoder parameters
        for name, param in decoder.named_parameters():
            # Map to our mBART naming convention
            mapped_name = self._map_mbart_tensor_name(name)
            decoder_dict[mapped_name] = param.detach().cpu()

            if self.verbose:
                print(f"  {name} -> {mapped_name} {list(param.shape)}")

        print(f"  Extracted {len(decoder_dict)} text decoder tensors")
        return decoder_dict

    def extract_projector(self) -> Dict[str, torch.Tensor]:
        """Extract the multimodal projector that connects the vision and text models."""
        print("Extracting multimodal projector...")

        projector_dict = {}

        # In Nougat, the projection happens through the decoder's cross-attention;
        # extract the projection matrices that connect encoder outputs to the decoder
        for name, param in self.model.decoder.named_parameters():
            if "encoder_attn" in name:
                # These cross-attention weights project from vision to text
                mapped_name = self._map_projector_tensor_name(name)
                projector_dict[mapped_name] = param.detach().cpu()

                if self.verbose:
                    print(f"  {name} -> {mapped_name} {list(param.shape)}")

        # If there is an explicit projection layer between encoder and decoder
        if hasattr(self.model, "enc_to_dec_proj"):
            projector_dict["mm.projector.weight"] = self.model.enc_to_dec_proj.weight.detach().cpu()
            if hasattr(self.model.enc_to_dec_proj, "bias"):
                projector_dict["mm.projector.bias"] = self.model.enc_to_dec_proj.bias.detach().cpu()

        print(f"  Extracted {len(projector_dict)} projector tensors")
        return projector_dict

    def _map_swin_tensor_name(self, name: str) -> str:
        """Map HuggingFace Swin tensor names to our convention."""

        # Remove the model prefix
        if name.startswith("model.encoder."):
            name = name[len("model.encoder."):]
        elif name.startswith("encoder."):
            name = name[len("encoder."):]

        # Patch embeddings
        if "embeddings.patch_embeddings" in name:
            if "projection.weight" in name:
                return "swin.patch_embed.weight"
            elif "projection.bias" in name:
                return "swin.patch_embed.bias"
            elif "norm" in name:
                return f"swin.patch_embed.norm.{'weight' if 'weight' in name else 'bias'}"

        # Position embeddings
        if "position_embeddings" in name:
            return "swin.pos_embed"

        # Parse the layer structure
        if "layers." in name:
            parts = name.split(".")
            stage_idx = None
            layer_idx = None

            # Find stage and layer indices
            for i, part in enumerate(parts):
                if part == "layers" and i + 1 < len(parts):
                    stage_idx = int(parts[i + 1])
                if part == "blocks" and i + 1 < len(parts):
                    layer_idx = int(parts[i + 1])

            if stage_idx is not None:
                suffix = "weight" if "weight" in name else "bias"

                # Layer-specific components
                if layer_idx is not None:
                    # Attention
                    if "attn.qkv" in name:
                        return f"swin.stage.{stage_idx}.layer.{layer_idx}.attn.qkv.{suffix}"
                    elif "attn.proj" in name:
                        return f"swin.stage.{stage_idx}.layer.{layer_idx}.attn.proj.{suffix}"
                    # Norms
                    elif "norm1" in name:
                        return f"swin.stage.{stage_idx}.layer.{layer_idx}.norm1.{suffix}"
                    elif "norm2" in name:
                        return f"swin.stage.{stage_idx}.layer.{layer_idx}.norm2.{suffix}"
                    # MLP/FFN
                    elif "mlp.fc1" in name:
                        return f"swin.stage.{stage_idx}.layer.{layer_idx}.mlp.fc1.{suffix}"
                    elif "mlp.fc2" in name:
                        return f"swin.stage.{stage_idx}.layer.{layer_idx}.mlp.fc2.{suffix}"
                    # Relative position bias
                    elif "relative_position_bias_table" in name:
                        return f"swin.stage.{stage_idx}.layer.{layer_idx}.attn.relative_position_bias_table"

                # Downsample layers between stages
                elif "downsample" in name:
                    if "norm" in name:
                        return f"swin.stage.{stage_idx}.downsample.norm.{suffix}"
                    elif "reduction" in name:
                        return f"swin.stage.{stage_idx}.downsample.reduction.weight"

        # Output normalization
        if "layernorm" in name or "layer_norm" in name:
            if "final" in name or "output" in name:
                suffix = "weight" if "weight" in name else "bias"
                return f"swin.norm.{suffix}"

        # Default mapping
        return f"swin.{name}"

    def _map_mbart_tensor_name(self, name: str) -> str:
        """Map HuggingFace mBART tensor names to our convention."""

        # Remove the model prefix
        if name.startswith("model.decoder."):
            name = name[len("model.decoder."):]
        elif name.startswith("decoder."):
            name = name[len("decoder."):]

        # Embeddings
        if name == "embed_tokens.weight" or name == "shared.weight":
            return "token_embd.weight"
        elif "embed_positions" in name:
            return "position_embd.weight"

        # Parse decoder layers
        if "layers." in name:
            parts = name.split(".")
            layer_idx = int(parts[1])
            suffix = "weight" if "weight" in name else "bias"

            # Self-attention
            if "self_attn.q_proj" in name:
                return f"blk.{layer_idx}.attn_q.weight"
            elif "self_attn.k_proj" in name:
                return f"blk.{layer_idx}.attn_k.weight"
            elif "self_attn.v_proj" in name:
                return f"blk.{layer_idx}.attn_v.weight"
            elif "self_attn.out_proj" in name:
                return f"blk.{layer_idx}.attn_o.weight"
            elif "self_attn_layer_norm" in name:
                return f"blk.{layer_idx}.attn_norm.{suffix}"

            # Cross-attention (encoder-decoder attention)
            elif "encoder_attn.q_proj" in name:
                return f"blk.{layer_idx}.attn_q_cross.weight"
            elif "encoder_attn.k_proj" in name:
                return f"blk.{layer_idx}.attn_k_cross.weight"
            elif "encoder_attn.v_proj" in name:
                return f"blk.{layer_idx}.attn_v_cross.weight"
            elif "encoder_attn.out_proj" in name:
                return f"blk.{layer_idx}.attn_o_cross.weight"
            elif "encoder_attn_layer_norm" in name:
                return f"blk.{layer_idx}.attn_norm_cross.{suffix}"

            # FFN
            elif "fc1" in name:
                return f"blk.{layer_idx}.ffn_up.weight"
            elif "fc2" in name:
                return f"blk.{layer_idx}.ffn_down.weight"
            elif "final_layer_norm" in name:
                return f"blk.{layer_idx}.ffn_norm.{suffix}"

        # Output layers
        elif "layernorm" in name or "layer_norm" in name:
            suffix = "weight" if "weight" in name else "bias"
            return f"output_norm.{suffix}"
        elif "lm_head" in name or "output_projection" in name:
            return "output.weight"

        # Default mapping
        return name

    def _map_projector_tensor_name(self, name: str) -> str:
        """Map cross-attention tensors to projector names."""

        # Extract the layer index; the "layers" component may be preceded by a
        # "model.decoder." prefix, so locate it explicitly rather than by position
        if "layers." in name:
            parts = name.split(".")
            layer_idx = int(parts[parts.index("layers") + 1])

            if "encoder_attn.q_proj" in name:
                return f"mm.layer.{layer_idx}.q_proj.weight"
            elif "encoder_attn.k_proj" in name:
                return f"mm.layer.{layer_idx}.k_proj.weight"
            elif "encoder_attn.v_proj" in name:
                return f"mm.layer.{layer_idx}.v_proj.weight"
            elif "encoder_attn.out_proj" in name:
                return f"mm.layer.{layer_idx}.out_proj.weight"

        return f"mm.{name}"

    def save_component(self, tensors: Dict[str, torch.Tensor], filename: str, arch_name: str, description: str):
        """Save component tensors to a GGUF file."""

        output_path = self.output_dir / filename
        print(f"Saving {arch_name} to {output_path}...")

        writer = gguf.GGUFWriter(str(output_path), arch_name)
        writer.add_string("general.name", arch_name)
        writer.add_string("general.description", description)
        writer.add_string("general.architecture", arch_name.lower())

        # Add tensors
        for name, tensor in tensors.items():
            data = tensor.float().cpu().numpy()
            writer.add_tensor(name, data)

        writer.write_header_to_file()
        writer.write_kv_data_to_file()
        writer.write_tensors_to_file()
        writer.close()

        print(f"  Saved {len(tensors)} tensors")

    def perform_surgery(self):
        """Main surgery operation - split the model into components."""

        print("\n" + "=" * 60)
        print("Starting Nougat Model Surgery")
        print("=" * 60)

        # Extract components (keep the projector around for save_config)
        vision_tensors = self.extract_vision_encoder()
        text_tensors = self.extract_text_decoder()
        self.projector_tensors = self.extract_projector()

        # Save components
        print("\nSaving components...")

        self.save_component(
            vision_tensors,
            "nougat-vision-swin.gguf",
            "Nougat-Vision-Swin",
            "Swin Transformer vision encoder from Nougat OCR model"
        )

        self.save_component(
            text_tensors,
            "nougat-text-mbart.gguf",
            "Nougat-Text-mBART",
            "mBART text decoder from Nougat OCR model"
        )

        if self.projector_tensors:
            self.save_component(
                self.projector_tensors,
                "nougat-projector.gguf",
                "Nougat-Projector",
                "Multimodal projector connecting vision and text models"
            )

        # Save the configuration
        self.save_config()

        print("\n" + "=" * 60)
        print("Surgery Complete!")
        print(f"Output files saved to: {self.output_dir}")
        print("=" * 60)

    def save_config(self):
        """Save the model configuration for reconstruction."""

        config = {
            "model_id": self.model_id,
            "vision_config": {
                "architecture": "swin",
                "image_size": 384,
                "patch_size": 4,
                "window_size": 7,
                "num_channels": 3,
                "depths": [2, 2, 6, 2],
                "num_heads": [3, 6, 12, 24],
            },
            "text_config": {
                "architecture": "mbart",
                "vocab_size": self.processor.tokenizer.vocab_size,
                "max_position_embeddings": 1024,
                "hidden_size": self.model.config.decoder.hidden_size,
                "num_layers": self.model.config.decoder.num_hidden_layers,
                "num_attention_heads": self.model.config.decoder.num_attention_heads,
            },
            "components": {
                "vision": "nougat-vision-swin.gguf",
                "text": "nougat-text-mbart.gguf",
                # reuse the already-extracted projector instead of extracting again
                "projector": "nougat-projector.gguf" if self.projector_tensors else None,
            }
        }

        config_path = self.output_dir / "nougat-config.json"
        with open(config_path, "w") as f:
            json.dump(config, f, indent=2)

        print(f"\nConfiguration saved to {config_path}")


def main():
    parser = argparse.ArgumentParser(description="Nougat Model Surgery - Split model into components")
    parser.add_argument(
        "--model-id",
        type=str,
        default="facebook/nougat-base",
        help="HuggingFace model ID or path to local model"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./models/nougat-surgery",
        help="Output directory for split components"
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Verbose output showing tensor mappings"
    )

    args = parser.parse_args()

    surgeon = NougatModelSurgeon(args.model_id, args.output_dir, args.verbose)
    surgeon.perform_surgery()


if __name__ == "__main__":
    main()
|
||||
|
|
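
# Example invocation (script name and paths are illustrative):
#   python nougat_surgery.py --model-id facebook/nougat-base \
#       --output-dir ./models/nougat-surgery --verbose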
@@ -0,0 +1,475 @@
#include "swin.h"
|
||||
#include "clip.h"
|
||||
#include "ggml.h"
|
||||
#include "ggml-alloc.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "gguf.h"
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
#include <array>
|
||||
|
||||
// Window partition operation - splits input into non-overlapping windows
struct ggml_tensor * swin_window_partition(struct ggml_context * ctx, struct ggml_tensor * x, int window_size) {
    // x shape:      [batch_size, height, width, channels] (ggml order: ne0 = C)
    // output shape: [batch_size * num_windows, window_size, window_size, channels]

    int batch_size = x->ne[3];
    int H = x->ne[2];
    int W = x->ne[1];
    int C = x->ne[0];

    int nH = H / window_size;
    int nW = W / window_size;

    // Group window rows: [C * window_size, window_size * nW, nH, batch_size]
    // (a 4D approximation of the 6D reshape used by the reference implementation)
    struct ggml_tensor * reshaped = ggml_reshape_4d(ctx, x,
        C * window_size,
        window_size * nW,
        nH,
        batch_size);

    // Swap the window-row and window-column axes; ggml_reshape requires a
    // contiguous tensor, so materialize the permutation first
    struct ggml_tensor * permuted = ggml_cont(ctx, ggml_permute(ctx, reshaped, 0, 2, 1, 3));

    // Collapse to [C, window_size, window_size, batch_size * nH * nW]
    struct ggml_tensor * output = ggml_reshape_4d(ctx, permuted,
        C,
        window_size,
        window_size,
        batch_size * nH * nW);

    return output;
}

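// Worked example: a 28x28 token grid with window_size = 7 and batch 1 gives
// nH = nW = 4, i.e. 16 windows of 7x7 tokens: [C, 28, 28, 1] -> [C, 7, 7, 16]
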
// Window reverse operation - merges windows back to original spatial dimensions
struct ggml_tensor * swin_window_reverse(struct ggml_context * ctx, struct ggml_tensor * windows, int window_size, int H, int W) {
    // windows shape: [batch_size * num_windows, window_size, window_size, channels]
    // output shape:  [batch_size, height, width, channels]

    int C = windows->ne[0];
    int nH = H / window_size;
    int nW = W / window_size;
    int batch_size = windows->ne[3] / (nH * nW);

    // Regroup the windows: [C * window_size * window_size, nW, nH, batch_size]
    struct ggml_tensor * reshaped = ggml_reshape_4d(ctx, windows,
        C * window_size * window_size,
        nW,
        nH,
        batch_size);

    // Swap the axes back and materialize before reshaping
    struct ggml_tensor * permuted = ggml_cont(ctx, ggml_permute(ctx, reshaped, 0, 2, 1, 3));

    // Restore [batch_size, H, W, C]
    struct ggml_tensor * output = ggml_reshape_4d(ctx, permuted, C, W, H, batch_size);

    return output;
}

// Create attention mask for shifted window attention
struct ggml_tensor * swin_create_window_mask(struct ggml_context * ctx, int window_size, int shift_size, int H, int W) {
    if (shift_size == 0) {
        return nullptr; // No mask needed for non-shifted windows
    }

    // Create a mask tensor holding a region index per spatial position.
    // NOTE: this writes to mask->data directly, so it must be built in a
    // context with allocated buffers (not the no_alloc compute context).
    struct ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, H, W);

    // Initialize mask with region indices: the cyclic shift creates up to
    // 3x3 = 9 regions whose positions must not attend across region borders
    float * mask_data = (float *)mask->data;
    int h_slices[] = {0, H - window_size, H - shift_size, H};
    int w_slices[] = {0, W - window_size, W - shift_size, W};

    int cnt = 0;
    for (int i = 0; i < 3; i++) {
        for (int j = 0; j < 3; j++) {
            for (int h = h_slices[i]; h < h_slices[i + 1]; h++) {
                for (int w = w_slices[j]; w < w_slices[j + 1]; w++) {
                    mask_data[h * W + w] = (float) cnt;
                }
            }
            cnt++;
        }
    }

    return mask;
}

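// Sketch (not wired up here): turning the per-position region ids into an
// additive attention bias. Positions in different regions must not attend
// to each other, so those logits get a large negative value before softmax:
//
//   for (int a = 0; a < N; a++)
//       for (int b = 0; b < N; b++)
//           attn_bias[a * N + b] = (region[a] == region[b]) ? 0.0f : -10000.0f;
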
// Build window attention layer
static struct ggml_tensor * swin_window_attention(
        struct ggml_context * ctx,
        struct ggml_tensor * x,
        const swin_layer & layer,
        int num_heads,
        int window_size,
        bool shifted) {

    int batch_size = x->ne[3];
    int seq_len    = x->ne[2] * x->ne[1]; // window_size * window_size
    int hidden_dim = x->ne[0];
    int head_dim   = hidden_dim / num_heads;

    // Reshape input for attention: [batch_size, seq_len, hidden_dim]
    x = ggml_reshape_3d(ctx, x, hidden_dim, seq_len, batch_size);

    // Layer norm (ggml_norm takes the epsilon, not a dimension; 1e-5f
    // matches the hparams.norm_eps default)
    x = ggml_norm(ctx, x, 1e-5f);
    x = ggml_add(ctx, ggml_mul(ctx, x, layer.ln1_w), layer.ln1_b);

    // QKV projection
    struct ggml_tensor * qkv = ggml_mul_mat(ctx, layer.qkv_w, x);
    qkv = ggml_add(ctx, qkv, layer.qkv_b);

    // Split into Q, K, V via strided views into the fused QKV tensor
    int qkv_dim = qkv->ne[0] / 3;
    struct ggml_tensor * q = ggml_view_3d(ctx, qkv, qkv_dim, seq_len, batch_size, qkv->nb[1], qkv->nb[2], 0);
    struct ggml_tensor * k = ggml_view_3d(ctx, qkv, qkv_dim, seq_len, batch_size, qkv->nb[1], qkv->nb[2], qkv_dim * ggml_element_size(qkv));
    struct ggml_tensor * v = ggml_view_3d(ctx, qkv, qkv_dim, seq_len, batch_size, qkv->nb[1], qkv->nb[2], 2 * qkv_dim * ggml_element_size(qkv));

    // Reshape for multi-head attention (views must be made contiguous first)
    q = ggml_reshape_4d(ctx, ggml_cont(ctx, q), head_dim, num_heads, seq_len, batch_size);
    k = ggml_reshape_4d(ctx, ggml_cont(ctx, k), head_dim, num_heads, seq_len, batch_size);
    v = ggml_reshape_4d(ctx, ggml_cont(ctx, v), head_dim, num_heads, seq_len, batch_size);

    // Q, K: [head_dim, seq_len, num_heads, batch]
    q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3));
    k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3));
    // V: [seq_len, head_dim, num_heads, batch], so that ggml_mul_mat(v, attn)
    // below directly yields the attention-weighted values
    v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3));

    // Scaled dot-product attention; ggml_mul_mat(k, q) is the standard
    // llama.cpp pattern for the Q.K^T scores
    float scale = 1.0f / sqrtf((float) head_dim);
    struct ggml_tensor * attn = ggml_mul_mat(ctx, k, q); // [seq_len, seq_len, num_heads, batch]
    attn = ggml_scale(ctx, attn, scale);

    // Add relative position bias if available
    if (layer.relative_position_bias_table != nullptr) {
        // This would need proper indexing based on relative positions
        // For now, simplified version
        attn = ggml_add(ctx, attn, layer.relative_position_bias_table);
    }

    // Apply mask for shifted window attention
    if (shifted) {
        // Create and apply attention mask
        struct ggml_tensor * mask = swin_create_window_mask(ctx, window_size, window_size / 2,
                                                            window_size, window_size);
        if (mask != nullptr) {
            // Convert mask to attention mask
            attn = ggml_add(ctx, attn, mask);
        }
    }

    // Softmax
    attn = ggml_soft_max(ctx, attn);

    // Apply attention to values: [head_dim, seq_len, num_heads, batch]
    struct ggml_tensor * out = ggml_mul_mat(ctx, v, attn);

    // Merge heads back: [hidden_dim, seq_len, batch]
    out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3));
    out = ggml_reshape_3d(ctx, out, hidden_dim, seq_len, batch_size);

    // Output projection
    out = ggml_mul_mat(ctx, layer.proj_w, out);
    out = ggml_add(ctx, out, layer.proj_b);

    return out;
}

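// Sketch of the proper relative-position-bias lookup alluded to above (an
// assumption following the Swin paper, not what the code does yet): for a
// window of side M, each query/key pair maps to a table row via its
// coordinate delta:
//
//   int dy  = (yq - yk) + (M - 1);       // 0 .. 2M-2
//   int dx  = (xq - xk) + (M - 1);       // 0 .. 2M-2
//   int idx = dy * (2 * M - 1) + dx;     // row in the (2M-1)^2 bias table
//   // bias added to attn[query][key] for head h: table[idx][h]
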
// Build FFN layer
static struct ggml_tensor * swin_ffn(
        struct ggml_context * ctx,
        struct ggml_tensor * x,
        const swin_layer & layer,
        float mlp_ratio) {

    // Layer norm (ggml_norm takes the epsilon, not a dimension)
    x = ggml_norm(ctx, x, 1e-5f);
    x = ggml_add(ctx, ggml_mul(ctx, x, layer.ln2_w), layer.ln2_b);

    // FFN: Linear -> GELU -> Linear
    // (mlp_ratio is implicit in the fc1/fc2 weight shapes)
    x = ggml_mul_mat(ctx, layer.fc1_w, x);
    x = ggml_add(ctx, x, layer.fc1_b);
    x = ggml_gelu(ctx, x);

    x = ggml_mul_mat(ctx, layer.fc2_w, x);
    x = ggml_add(ctx, x, layer.fc2_b);

    return x;
}

// Build Swin Transformer block
static struct ggml_tensor * swin_block(
        struct ggml_context * ctx,
        struct ggml_tensor * x,
        const swin_layer & layer,
        int num_heads,
        int window_size,
        bool shifted,
        float mlp_ratio) {

    int H = x->ne[2];
    int W = x->ne[1];

    struct ggml_tensor * shortcut = x;

    // Shifted window partitioning if needed
    if (shifted && (H > window_size || W > window_size)) {
        // Cyclic shift along W (ne1) and H (ne2); ggml_roll takes one
        // shift per dimension
        int shift_size = window_size / 2;
        x = ggml_roll(ctx, x, 0, -shift_size, -shift_size, 0);
    }

    // Partition into windows
    if (H > window_size || W > window_size) {
        x = swin_window_partition(ctx, x, window_size);
    }

    // Window attention
    x = swin_window_attention(ctx, x, layer, num_heads, window_size, shifted);

    // Reverse window partition
    if (H > window_size || W > window_size) {
        x = swin_window_reverse(ctx, x, window_size, H, W);
    }

    // Reverse cyclic shift if needed
    if (shifted && (H > window_size || W > window_size)) {
        int shift_size = window_size / 2;
        x = ggml_roll(ctx, x, 0, shift_size, shift_size, 0);
    }

    // Residual connection
    x = ggml_add(ctx, x, shortcut);

    // FFN with residual
    shortcut = x;
    x = swin_ffn(ctx, x, layer, mlp_ratio);
    x = ggml_add(ctx, x, shortcut);

    return x;
}

// Patch merging layer (downsampling)
static struct ggml_tensor * swin_patch_merging(
        struct ggml_context * ctx,
        struct ggml_tensor * x,
        struct ggml_tensor * norm_w,
        struct ggml_tensor * norm_b,
        struct ggml_tensor * reduction) {

    int batch_size = x->ne[3];
    int H = x->ne[2];
    int W = x->ne[1];
    int C = x->ne[0];

    // Reshape to merge 2x2 patches (4D approximation of the reference 6D
    // gather; the permutation must be materialized before reshaping)
    x = ggml_reshape_4d(ctx, x, C, W/2, 2, H/2 * 2 * batch_size);
    x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));
    x = ggml_reshape_4d(ctx, x, C * 4, W/2, H/2, batch_size);

    // Layer norm (ggml_norm takes the epsilon, not a dimension)
    x = ggml_norm(ctx, x, 1e-5f);
    x = ggml_add(ctx, ggml_mul(ctx, x, norm_w), norm_b);

    // Linear reduction: 4C -> 2C
    x = ggml_mul_mat(ctx, reduction, x);

    return x;
}

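// Worked example with the nougat-base defaults (image 384, patch 4,
// hidden_dim 96): the first stage's [96, 96, 96, B] output is regrouped
// to [384, 48, 48, B] and the 4C -> 2C reduction yields [192, 48, 48, B].
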
// Build complete Swin Transformer graph
struct ggml_cgraph * swin_build_graph(
        struct swin_ctx * ctx,
        const swin_image_batch * imgs,
        std::pair<int, int> load_image_size,
        bool is_inf) {

    if (!ctx->has_vision_encoder) {
        return nullptr;
    }

    const auto & model   = ctx->vision_model;
    const auto & hparams = model.hparams;

    struct ggml_init_params params = {
        /*.mem_size   =*/ ctx->buf_compute_meta.size(),
        /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
        /*.no_alloc   =*/ true,
    };

    struct ggml_context * ctx0 = ggml_init(params);
    struct ggml_cgraph * cgraph = ggml_new_graph(ctx0);

    const int batch_size       = imgs->size;
    const int image_size       = hparams.image_size;
    const int patch_size       = hparams.patch_size;
    const int num_patches_side = image_size / patch_size;
    const int num_patches      = num_patches_side * num_patches_side;
    const int hidden_dim       = hparams.hidden_dim;

    // Input image tensor
    struct ggml_tensor * inp = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32,
        3, image_size, image_size, batch_size);
    ggml_set_name(inp, "inp");

    // Patch embedding: Conv2D with stride=patch_size
    struct ggml_tensor * x = ggml_conv_2d(ctx0, model.patch_embed, inp, patch_size, patch_size, 0, 0, 1, 1);

    // Conv output is [W/p, H/p, hidden_dim, batch]; bring channels into ne0
    // to get [hidden_dim, num_patches, batch] (same pattern as clip.cpp)
    x = ggml_reshape_3d(ctx0, x, num_patches, hidden_dim, batch_size);
    x = ggml_cont(ctx0, ggml_permute(ctx0, x, 1, 0, 2, 3));

    // Add positional embeddings if available
    if (model.pos_embed != nullptr) {
        x = ggml_add(ctx0, x, model.pos_embed);
    }

    // Layer norm after patch embedding
    if (model.patch_norm_w != nullptr) {
        x = ggml_norm(ctx0, x, hparams.norm_eps);
        x = ggml_add(ctx0, ggml_mul(ctx0, x, model.patch_norm_w), model.patch_norm_b);
    }

    // Reshape for spatial processing
    x = ggml_reshape_4d(ctx0, x, hidden_dim, num_patches_side, num_patches_side, batch_size);

    // Process through Swin stages
    int H = num_patches_side;
    int W = num_patches_side;
    int C = hidden_dim;

    for (size_t stage_idx = 0; stage_idx < model.stages.size(); stage_idx++) {
        const auto & stage = model.stages[stage_idx];

        // Process layers in this stage
        for (size_t layer_idx = 0; layer_idx < stage.layers.size(); layer_idx++) {
            const auto & layer = stage.layers[layer_idx];
            bool shifted = (layer_idx % 2 == 1); // Alternate between regular and shifted windows

            x = swin_block(ctx0, x, layer,
                hparams.num_heads[stage_idx],
                hparams.window_size,
                shifted,
                hparams.mlp_ratio);
        }

        // Patch merging (downsampling) between stages, except for the last stage
        if (stage_idx < model.stages.size() - 1 && stage.downsample_reduction != nullptr) {
            x = swin_patch_merging(ctx0, x,
                stage.downsample_norm_w,
                stage.downsample_norm_b,
                stage.downsample_reduction);
            H /= 2;
            W /= 2;
            C *= 2; // Channel dimension doubles after patch merging
        }
    }

    // Global average pooling over the spatial positions.
    // ggml_mean averages along ne0, so transpose to put H*W first.
    x = ggml_reshape_3d(ctx0, x, C, H * W, batch_size);
    x = ggml_cont(ctx0, ggml_transpose(ctx0, x)); // [H*W, C, batch]
    x = ggml_mean(ctx0, x);                       // [1, C, batch]
    x = ggml_reshape_2d(ctx0, x, C, batch_size);

    // Final layer norm
    if (model.output_norm_w != nullptr) {
        x = ggml_norm(ctx0, x, hparams.norm_eps);
        x = ggml_add(ctx0, ggml_mul(ctx0, x, model.output_norm_w), model.output_norm_b);
    }

    ggml_set_name(x, "output");
    ggml_build_forward_expand(cgraph, x);

    return cgraph;
}

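// Usage sketch (assumption: buffers allocated via ggml_gallocr, as the
// clip.cpp vision encoders do):
//
//   struct ggml_cgraph * gf = swin_build_graph(ctx, &imgs, {0, 0}, true);
//   ggml_gallocr_alloc_graph(ctx->compute_alloc, gf);
//   // ...copy the preprocessed image into the tensor named "inp"...
//   ggml_backend_graph_compute(ctx->backend, gf);
//   // ...read the pooled embedding back from the tensor named "output"...
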
// Model loading function
struct swin_ctx * swin_model_load(const std::string & fname, int verbosity) {
    struct swin_ctx * ctx = new swin_ctx();

    struct gguf_init_params params = {
        /*.no_alloc = */ true,
        /*.ctx      = */ &ctx->ctx,
    };

    struct gguf_context * gguf_ctx = gguf_init_from_file(fname.c_str(), params);
    if (!gguf_ctx) {
        fprintf(stderr, "%s: failed to load model from %s\n", __func__, fname.c_str());
        swin_free(ctx);
        return nullptr;
    }

    // Load hyperparameters
    auto & hparams = ctx->vision_model.hparams;

    // Read Swin-specific parameters from GGUF
    const int n_kv = gguf_get_n_kv(gguf_ctx);
    for (int i = 0; i < n_kv; ++i) {
        const char * key = gguf_get_key(gguf_ctx, i);

        if (strcmp(key, KEY_SWIN_WINDOW_SIZE) == 0) {
            hparams.window_size = gguf_get_val_i32(gguf_ctx, i);
        } else if (strcmp(key, KEY_SWIN_PATCH_SIZE) == 0) {
            hparams.patch_size = gguf_get_val_i32(gguf_ctx, i);
        } else if (strcmp(key, KEY_SWIN_IMAGE_SIZE) == 0) {
            hparams.image_size = gguf_get_val_i32(gguf_ctx, i);
        } else if (strcmp(key, KEY_SWIN_HIDDEN_DIM) == 0) {
            hparams.hidden_dim = gguf_get_val_i32(gguf_ctx, i);
        } else if (strcmp(key, KEY_SWIN_MLP_RATIO) == 0) {
            hparams.mlp_ratio = gguf_get_val_f32(gguf_ctx, i);
        } else if (strcmp(key, KEY_SWIN_NORM_EPS) == 0) {
            hparams.norm_eps = gguf_get_val_f32(gguf_ctx, i);
        }
        // TODO: Load depths and num_heads arrays
    }

    ctx->has_vision_encoder = true;

    if (verbosity >= 1) {
        printf("Swin Transformer model loaded:\n");
        printf("  image_size:  %d\n", hparams.image_size);
        printf("  patch_size:  %d\n", hparams.patch_size);
        printf("  window_size: %d\n", hparams.window_size);
        printf("  hidden_dim:  %d\n", hparams.hidden_dim);
        printf("  num_stages:  %d\n", hparams.num_stages());
    }

    // TODO: Load actual tensor weights from GGUF file

    gguf_free(gguf_ctx);

    return ctx;
}

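// Sketch for the depths/num_heads TODO above (assumption: both are stored
// as GGUF i32 arrays under KEY_SWIN_DEPTHS / KEY_SWIN_NUM_HEADS):
//
//   const int idx = gguf_find_key(gguf_ctx, KEY_SWIN_DEPTHS);
//   if (idx >= 0) {
//       const int32_t * vals = (const int32_t *) gguf_get_arr_data(gguf_ctx, idx);
//       hparams.depths.assign(vals, vals + gguf_get_arr_n(gguf_ctx, idx));
//   }
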
// Free context
void swin_free(struct swin_ctx * ctx) {
    if (ctx == nullptr) {
        return;
    }

    if (ctx->backend) {
        ggml_backend_free(ctx->backend);
    }

    if (ctx->params_buffer) {
        ggml_backend_buffer_free(ctx->params_buffer);
    }

    if (ctx->compute_buffer) {
        ggml_backend_buffer_free(ctx->compute_buffer);
    }

    if (ctx->ctx) {
        ggml_free(ctx->ctx);
    }

    delete ctx;
}

@@ -0,0 +1,153 @@
#pragma once

#include "ggml.h"
#include "ggml-backend.h"
#include "clip-impl.h"

#include <string>
#include <utility>
#include <vector>

// Swin Transformer constants
#define KEY_SWIN_WINDOW_SIZE    "swin.window_size"
#define KEY_SWIN_PATCH_SIZE     "swin.patch_size"
#define KEY_SWIN_IMAGE_SIZE     "swin.image_size"
#define KEY_SWIN_DEPTHS         "swin.depths"
#define KEY_SWIN_NUM_HEADS      "swin.num_heads"
#define KEY_SWIN_HIDDEN_DIM     "swin.hidden_dim"
#define KEY_SWIN_NUM_CHANNELS   "swin.num_channels"
#define KEY_SWIN_MLP_RATIO      "swin.mlp_ratio"
#define KEY_SWIN_DROP_PATH_RATE "swin.drop_path_rate"
#define KEY_SWIN_NORM_EPS       "swin.norm_eps"

// Tensor names for Swin Transformer
#define TN_SWIN_PATCH_EMBED     "swin.patch_embed.weight"
#define TN_SWIN_PATCH_NORM      "swin.patch_embed.norm.%s"
#define TN_SWIN_POS_EMBED       "swin.pos_embed"
#define TN_SWIN_DOWNSAMPLE_NORM "swin.stage.%d.downsample.norm.%s"
#define TN_SWIN_DOWNSAMPLE_PROJ "swin.stage.%d.downsample.reduction.weight"
#define TN_SWIN_ATTN_NORM       "swin.stage.%d.layer.%d.norm1.%s"
#define TN_SWIN_ATTN_QKV        "swin.stage.%d.layer.%d.attn.qkv.%s"
#define TN_SWIN_ATTN_PROJ       "swin.stage.%d.layer.%d.attn.proj.%s"
#define TN_SWIN_ATTN_REL_POS    "swin.stage.%d.layer.%d.attn.relative_position_bias_table"
#define TN_SWIN_FFN_NORM        "swin.stage.%d.layer.%d.norm2.%s"
#define TN_SWIN_FFN_FC1         "swin.stage.%d.layer.%d.mlp.fc1.%s"
#define TN_SWIN_FFN_FC2         "swin.stage.%d.layer.%d.mlp.fc2.%s"
#define TN_SWIN_OUTPUT_NORM     "swin.norm.%s"

// Forward declarations
struct swin_ctx;

// Swin Transformer hyperparameters
struct swin_hparams {
    int32_t image_size   = 384;
    int32_t patch_size   = 4;
    int32_t num_channels = 3;
    int32_t window_size  = 7;
    int32_t hidden_dim   = 96;
    std::vector<int32_t> depths    = {2, 2, 6, 2};   // depths for each stage
    std::vector<int32_t> num_heads = {3, 6, 12, 24}; // number of heads for each stage
    float mlp_ratio      = 4.0f;
    float drop_path_rate = 0.1f;
    float norm_eps       = 1e-5f;
    bool use_checkpoint  = false;

    // Computed values
    int32_t num_stages()  const { return (int32_t) depths.size(); }
    int32_t num_patches() const { return (image_size / patch_size) * (image_size / patch_size); }
};

// Swin Transformer layer
struct swin_layer {
    // Window attention
    struct ggml_tensor * ln1_w;
    struct ggml_tensor * ln1_b;
    struct ggml_tensor * qkv_w;
    struct ggml_tensor * qkv_b;
    struct ggml_tensor * proj_w;
    struct ggml_tensor * proj_b;
    struct ggml_tensor * relative_position_bias_table;

    // FFN
    struct ggml_tensor * ln2_w;
    struct ggml_tensor * ln2_b;
    struct ggml_tensor * fc1_w;
    struct ggml_tensor * fc1_b;
    struct ggml_tensor * fc2_w;
    struct ggml_tensor * fc2_b;
};

// Swin Transformer stage
struct swin_stage {
    std::vector<swin_layer> layers;

    // Patch merging (downsample) layer
    struct ggml_tensor * downsample_norm_w = nullptr;
    struct ggml_tensor * downsample_norm_b = nullptr;
    struct ggml_tensor * downsample_reduction = nullptr;
};

// Swin Transformer vision model
struct swin_vision_model {
    swin_hparams hparams;

    // Patch embedding
    struct ggml_tensor * patch_embed;
    struct ggml_tensor * patch_norm_w;
    struct ggml_tensor * patch_norm_b;
    struct ggml_tensor * pos_embed;

    // Stages
    std::vector<swin_stage> stages;

    // Output norm
    struct ggml_tensor * output_norm_w;
    struct ggml_tensor * output_norm_b;
};

// Main Swin context
struct swin_ctx {
    bool has_vision_encoder = false;
    bool has_projector      = false;

    swin_vision_model vision_model;

    // Backend and compute
    ggml_backend_t        backend       = nullptr;
    ggml_backend_buffer_t params_buffer = nullptr;

    struct ggml_context * ctx = nullptr;
    std::vector<uint8_t> buf_compute_meta;

    // GGML compute resources
    ggml_backend_buffer_t compute_buffer = nullptr;
    struct ggml_context * ctx_compute    = nullptr;
    ggml_gallocr_t        compute_alloc  = nullptr;
};

// Public API functions
struct swin_ctx * swin_model_load(const std::string & fname, int verbosity = 1);
void swin_free(struct swin_ctx * ctx);

// Build Swin Transformer graph for inference
struct ggml_cgraph * swin_build_graph(
    struct swin_ctx * ctx,
    const swin_image_batch * imgs,
    std::pair<int, int> load_image_size = {0, 0},
    bool is_inf = false);

// Encode image batch
bool swin_image_batch_encode(
    struct swin_ctx * ctx,
    int n_threads,
    const swin_image_batch * imgs,
    float * vec);

// Utility functions
int swin_patch_size(const struct swin_ctx * ctx);
bool swin_image_preprocess(struct swin_ctx * ctx, const swin_image_u8 * img, swin_image_f32 * res);
bool swin_image_batch_preprocess(struct swin_ctx * ctx, int n_threads, const swin_image_batch * imgs, swin_image_f32_batch * res_batch);

// Window operations for Swin Transformer
struct ggml_tensor * swin_window_partition(struct ggml_context * ctx, struct ggml_tensor * x, int window_size);
struct ggml_tensor * swin_window_reverse(struct ggml_context * ctx, struct ggml_tensor * windows, int window_size, int H, int W);
struct ggml_tensor * swin_create_window_mask(struct ggml_context * ctx, int window_size, int shift_size, int H, int W);

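// Usage sketch (paths illustrative):
//
//   struct swin_ctx * ctx = swin_model_load("models/nougat-vision-swin.gguf");
//   // ...fill a swin_image_batch via swin_image_batch_preprocess()...
//   std::vector<float> emb(/* embedding size for the model */);
//   swin_image_batch_encode(ctx, /*n_threads=*/4, &imgs, emb.data());
//   swin_free(ctx);
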
@@ -0,0 +1,539 @@
#include "common.h"
|
||||
#include "ggml.h"
|
||||
#include "llama.h"
|
||||
#include "mtmd/swin.h"
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <thread>
|
||||
#include <chrono>
|
||||
|
||||
// External preprocessing function
extern "C" bool nougat_preprocess_pipeline(
    const char* input_path,
    float** output_data,
    int* output_width,
    int* output_height,
    int* num_pages);

extern "C" void nougat_preprocess_cleanup(float* data);

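// Buffer layout assumed by the page loop in main(): output_data holds
// num_pages consecutive pages, each width * height * 3 floats (RGB).
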
// CLI arguments structure
struct nougat_params {
    std::string input_path  = "";
    std::string output_path = "";
    std::string vision_model    = "models/nougat-vision-swin.gguf";
    std::string text_model      = "models/nougat-text-mbart.gguf";
    std::string projector_model = "models/nougat-projector.gguf";

    // Processing options
    bool batch_mode  = false;
    int batch_size   = 1;
    int n_threads    = 4;
    int n_gpu_layers = 0;

    // Output options
    std::string output_format = "markdown"; // markdown, latex, plain
    bool verbose           = false;
    bool save_intermediate = false;

    // Performance options
    bool use_mmap       = true;
    bool use_flash_attn = false;
    int context_size    = 2048;

    // Document-specific options
    bool deskew        = true;
    bool denoise       = true;
    bool detect_tables = true;
    bool detect_math   = true;
    int max_pages      = -1; // -1 for all pages
};

static void print_usage(const char* prog_name) {
    fprintf(stdout, "\n");
    fprintf(stdout, "Nougat OCR - Neural Optical Understanding for Academic Documents\n");
    fprintf(stdout, "\n");
    fprintf(stdout, "Usage: %s [options] -i input_file -o output_file\n", prog_name);
    fprintf(stdout, "\n");
    fprintf(stdout, "Options:\n");
    fprintf(stdout, "  -i, --input FILE         Input document (PDF, PNG, JPG)\n");
    fprintf(stdout, "  -o, --output FILE        Output file path\n");
    fprintf(stdout, "  --vision-model FILE      Path to vision model GGUF (default: models/nougat-vision-swin.gguf)\n");
    fprintf(stdout, "  --text-model FILE        Path to text model GGUF (default: models/nougat-text-mbart.gguf)\n");
    fprintf(stdout, "  --projector FILE         Path to projector model GGUF (default: models/nougat-projector.gguf)\n");
    fprintf(stdout, "\n");
    fprintf(stdout, " Processing Options:\n");
    fprintf(stdout, "  -t, --threads N          Number of threads (default: 4)\n");
    fprintf(stdout, "  -ngl, --n-gpu-layers N   Number of layers to offload to GPU (default: 0)\n");
    fprintf(stdout, "  -b, --batch-size N       Batch size for processing (default: 1)\n");
    fprintf(stdout, "  -c, --context-size N     Context size (default: 2048)\n");
    fprintf(stdout, "  --max-pages N            Maximum pages to process (default: all)\n");
    fprintf(stdout, "\n");
    fprintf(stdout, " Output Options:\n");
    fprintf(stdout, "  -f, --format FORMAT      Output format: markdown, latex, plain (default: markdown)\n");
    fprintf(stdout, "  -v, --verbose            Verbose output\n");
    fprintf(stdout, "  --save-intermediate      Save intermediate processing results\n");
    fprintf(stdout, "\n");
    fprintf(stdout, " Document Processing:\n");
    fprintf(stdout, "  --no-deskew              Disable automatic deskewing\n");
    fprintf(stdout, "  --no-denoise             Disable denoising\n");
    fprintf(stdout, "  --no-tables              Disable table detection\n");
    fprintf(stdout, "  --no-math                Disable math formula detection\n");
    fprintf(stdout, "\n");
    fprintf(stdout, " Performance Options:\n");
    fprintf(stdout, "  --no-mmap                Disable memory mapping\n");
    fprintf(stdout, "  --flash-attn             Use flash attention\n");
    fprintf(stdout, "\n");
    fprintf(stdout, "Examples:\n");
    fprintf(stdout, "  # Basic OCR of a PDF document\n");
    fprintf(stdout, "  %s -i paper.pdf -o paper.md\n", prog_name);
    fprintf(stdout, "\n");
    fprintf(stdout, "  # Process with GPU acceleration\n");
    fprintf(stdout, "  %s -i scan.png -o text.md -ngl 32 -t 8\n", prog_name);
    fprintf(stdout, "\n");
    fprintf(stdout, "  # LaTeX output (math detection is on by default)\n");
    fprintf(stdout, "  %s -i math_paper.pdf -o paper.tex -f latex\n", prog_name);
    fprintf(stdout, "\n");
}

static bool parse_args(int argc, char** argv, nougat_params& params) {
    // Fetch the value following a flag, or report a missing argument
    auto next_arg = [&](int& i, const std::string& arg) -> const char* {
        if (++i >= argc) {
            fprintf(stderr, "Error: Missing argument for %s\n", arg.c_str());
            return nullptr;
        }
        return argv[i];
    };

    for (int i = 1; i < argc; i++) {
        std::string arg = argv[i];
        const char* val = nullptr;

        if (arg == "-i" || arg == "--input") {
            if (!(val = next_arg(i, arg))) return false;
            params.input_path = val;
        }
        else if (arg == "-o" || arg == "--output") {
            if (!(val = next_arg(i, arg))) return false;
            params.output_path = val;
        }
        else if (arg == "--vision-model") {
            if (!(val = next_arg(i, arg))) return false;
            params.vision_model = val;
        }
        else if (arg == "--text-model") {
            if (!(val = next_arg(i, arg))) return false;
            params.text_model = val;
        }
        else if (arg == "--projector") {
            if (!(val = next_arg(i, arg))) return false;
            params.projector_model = val;
        }
        else if (arg == "-t" || arg == "--threads") {
            if (!(val = next_arg(i, arg))) return false;
            params.n_threads = std::stoi(val);
        }
        else if (arg == "-ngl" || arg == "--n-gpu-layers") {
            if (!(val = next_arg(i, arg))) return false;
            params.n_gpu_layers = std::stoi(val);
        }
        else if (arg == "-b" || arg == "--batch-size") {
            if (!(val = next_arg(i, arg))) return false;
            params.batch_size = std::stoi(val);
        }
        else if (arg == "-c" || arg == "--context-size") {
            if (!(val = next_arg(i, arg))) return false;
            params.context_size = std::stoi(val);
        }
        else if (arg == "--max-pages") {
            if (!(val = next_arg(i, arg))) return false;
            params.max_pages = std::stoi(val);
        }
        else if (arg == "-f" || arg == "--format") {
            if (!(val = next_arg(i, arg))) return false;
            params.output_format = val;
        }
        else if (arg == "-v" || arg == "--verbose") {
            params.verbose = true;
        }
        else if (arg == "--save-intermediate") {
            params.save_intermediate = true;
        }
        else if (arg == "--no-deskew") {
            params.deskew = false;
        }
        else if (arg == "--no-denoise") {
            params.denoise = false;
        }
        else if (arg == "--no-tables") {
            params.detect_tables = false;
        }
        else if (arg == "--no-math") {
            params.detect_math = false;
        }
        else if (arg == "--no-mmap") {
            params.use_mmap = false;
        }
        else if (arg == "--flash-attn") {
            params.use_flash_attn = true;
        }
        else if (arg == "-h" || arg == "--help") {
            print_usage(argv[0]);
            exit(0);
        }
        else {
            fprintf(stderr, "Error: Unknown argument '%s'\n", arg.c_str());
            return false;
        }
    }

    // Validate required arguments
    if (params.input_path.empty()) {
        fprintf(stderr, "Error: Input file is required\n");
        return false;
    }

    if (params.output_path.empty()) {
        // Derive the default output path from the input file name
        size_t dot_pos = params.input_path.find_last_of(".");
        params.output_path = params.input_path.substr(0, dot_pos);

        if (params.output_format == "markdown") {
            params.output_path += ".md";
        } else if (params.output_format == "latex") {
            params.output_path += ".tex";
        } else {
            params.output_path += ".txt";
        }
    }

    return true;
}

// Process a single page through the Nougat pipeline
static std::string process_page(
        struct swin_ctx* vision_ctx,
        struct llama_model* text_model,
        struct llama_context* text_ctx,
        const float* image_data,
        int width,
        int height,
        const nougat_params& params) {

    // Step 1: Encode image with Swin Transformer
    if (params.verbose) {
        printf("Encoding image with Swin Transformer...\n");
    }

    // Create image batch
    swin_image_f32 img = {
        width,
        height,
        3,
        std::vector<float>(image_data, image_data + width * height * 3)
    };

    swin_image_batch imgs = {1, &img};

    // Encode image
    std::vector<float> vision_embeddings(2048); // TODO: size this from the model's embedding dim
    if (!swin_image_batch_encode(vision_ctx, params.n_threads, &imgs, vision_embeddings.data())) {
        fprintf(stderr, "Failed to encode image\n");
        return "";
    }

    // Step 2: Pass embeddings through projector
    // This would map vision embeddings to text embedding space

    // Step 3: Generate text with mBART decoder
    if (params.verbose) {
        printf("Generating text with mBART decoder...\n");
    }

    // Create batch for text generation
    llama_batch batch = llama_batch_init(params.context_size, 0, 1);

    // Set up cross-attention with vision embeddings
    // This requires the decoder to attend to encoder outputs

    // Start with BOS token
    llama_token bos_token = llama_token_bos(text_model);
    batch.token[0]     = bos_token;
    batch.pos[0]       = 0;
    batch.n_seq_id[0]  = 1;
    batch.seq_id[0][0] = 0;
    batch.logits[0]    = true; // we need logits for this position
    batch.n_tokens     = 1;

    // Decode initial token
    if (llama_decode(text_ctx, batch) != 0) {
        fprintf(stderr, "Failed to decode\n");
        llama_batch_free(batch);
        return "";
    }

    // Generate text autoregressively
    std::vector<llama_token> generated_tokens;
    generated_tokens.push_back(bos_token);

    llama_token eos_token = llama_token_eos(text_model);
    int max_tokens = params.context_size;

    for (int i = 1; i < max_tokens; i++) {
        // Get logits from last position
        float* logits = llama_get_logits_ith(text_ctx, batch.n_tokens - 1);

        // Sample next token
        int n_vocab = llama_n_vocab(text_model);
        std::vector<llama_token_data> candidates;
        candidates.reserve(n_vocab);

        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
        }

        llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};

        // Sample with top-k and top-p
        int top_k   = 40;
        float top_p = 0.9f;
        float temp  = 0.8f;

        llama_sample_top_k(text_ctx, &candidates_p, top_k, 1);
        llama_sample_top_p(text_ctx, &candidates_p, top_p, 1);
        llama_sample_temp(text_ctx, &candidates_p, temp);

        llama_token new_token = llama_sample_token(text_ctx, &candidates_p);

        // Check for EOS
        if (new_token == eos_token) {
            break;
        }

        generated_tokens.push_back(new_token);

        // Add to batch for next iteration
        batch.token[0]  = new_token;
        batch.pos[0]    = i;
        batch.logits[0] = true;
        batch.n_tokens  = 1;

        if (llama_decode(text_ctx, batch) != 0) {
            fprintf(stderr, "Failed to continue decoding\n");
            break;
        }
    }

    llama_batch_free(batch);

    // Convert tokens to text
    std::string result;
    for (auto token : generated_tokens) {
        std::string piece = llama_token_to_piece(text_ctx, token, true);
        result += piece;
    }

    return result;
}

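// Sketch of the missing projector step above (assumption: a single linear
// layer with weights loaded from nougat-projector.gguf):
//
//   // vision_emb: [d_vision], proj_w: [d_vision, d_text], proj_b: [d_text]
//   struct ggml_tensor * emb = ggml_mul_mat(ctx, proj_w, vision_emb);
//   emb = ggml_add(ctx, emb, proj_b);
//
// The projected embeddings would then stand in for encoder hidden states in
// the decoder's cross-attention, which llama.cpp's decoder-only API does
// not expose yet - hence the placeholder comments above.
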
int main(int argc, char** argv) {
    nougat_params params;

    // Parse command line arguments
    if (!parse_args(argc, argv, params)) {
        print_usage(argv[0]);
        return 1;
    }

    // Print banner
    printf("\n");
    printf("╔═══════════════════════════════════════════════════════╗\n");
    printf("║          Nougat OCR - Document Understanding           ║\n");
    printf("║          Powered by Swin Transformer + mBART           ║\n");
    printf("╚═══════════════════════════════════════════════════════╝\n");
    printf("\n");

    printf("Input:  %s\n", params.input_path.c_str());
    printf("Output: %s\n", params.output_path.c_str());
    printf("Format: %s\n", params.output_format.c_str());
    printf("\n");

    // Initialize backend
    llama_backend_init();

    // Load vision model (Swin Transformer)
    printf("Loading vision model from %s...\n", params.vision_model.c_str());
    struct swin_ctx* vision_ctx = swin_model_load(params.vision_model, params.verbose ? 2 : 1);
    if (!vision_ctx) {
        fprintf(stderr, "Failed to load vision model\n");
        return 1;
    }

    // Load text model (mBART)
    printf("Loading text model from %s...\n", params.text_model.c_str());

    llama_model_params model_params = llama_model_default_params();
    model_params.n_gpu_layers = params.n_gpu_layers;
    model_params.use_mmap     = params.use_mmap;

    struct llama_model* text_model = llama_load_model_from_file(
        params.text_model.c_str(), model_params);
    if (!text_model) {
        fprintf(stderr, "Failed to load text model\n");
        swin_free(vision_ctx);
        return 1;
    }

    // Create text generation context
    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_ctx           = params.context_size;
    ctx_params.n_threads       = params.n_threads;
    ctx_params.n_threads_batch = params.n_threads;
    ctx_params.flash_attn      = params.use_flash_attn;

    struct llama_context* text_ctx = llama_new_context_with_model(text_model, ctx_params);
    if (!text_ctx) {
        fprintf(stderr, "Failed to create text context\n");
        llama_free_model(text_model);
        swin_free(vision_ctx);
        return 1;
    }

    // Preprocess document
    printf("Preprocessing document...\n");
    float* preprocessed_data = nullptr;
    int width, height, num_pages;

    if (!nougat_preprocess_pipeline(
            params.input_path.c_str(),
            &preprocessed_data,
            &width, &height, &num_pages)) {
        fprintf(stderr, "Failed to preprocess document\n");
        llama_free(text_ctx);
        llama_free_model(text_model);
        swin_free(vision_ctx);
        return 1;
    }

    printf("Document info: %d pages, %dx%d pixels\n", num_pages, width, height);

    // Limit pages if requested
    if (params.max_pages > 0 && num_pages > params.max_pages) {
        num_pages = params.max_pages;
        printf("Processing first %d pages only\n", num_pages);
    }

    // Process each page
    std::string full_output;
    auto start_time = std::chrono::high_resolution_clock::now();

    for (int page = 0; page < num_pages; page++) {
        printf("\nProcessing page %d/%d...\n", page + 1, num_pages);

        float* page_data = preprocessed_data + (page * width * height * 3);

        std::string page_text = process_page(
            vision_ctx, text_model, text_ctx,
            page_data, width, height, params);

        if (page_text.empty()) {
            fprintf(stderr, "Warning: Failed to process page %d\n", page + 1);
            continue;
        }

        // Add page separator for multi-page documents
        if (page > 0) {
            if (params.output_format == "markdown") {
                full_output += "\n\n---\n\n";
            } else if (params.output_format == "latex") {
                full_output += "\n\\newpage\n\n";
            } else {
                full_output += "\n\n[Page " + std::to_string(page + 1) + "]\n\n";
            }
        }

        full_output += page_text;

        // Save intermediate results if requested
        if (params.save_intermediate) {
            std::string intermediate_file = params.output_path + ".page" +
                                            std::to_string(page + 1) + ".tmp";
            std::ofstream tmp_out(intermediate_file);
            tmp_out << page_text;
            tmp_out.close();
        }
    }

    auto end_time = std::chrono::high_resolution_clock::now();
    auto duration = std::chrono::duration_cast<std::chrono::seconds>(end_time - start_time);

    // Save final output
    printf("\nSaving output to %s...\n", params.output_path.c_str());
    std::ofstream output_file(params.output_path);
    if (!output_file) {
        fprintf(stderr, "Failed to open output file\n");
    } else {
        // Add format-specific headers/footers
        if (params.output_format == "latex") {
            output_file << "\\documentclass{article}\n";
            output_file << "\\usepackage{amsmath}\n";
            output_file << "\\usepackage{graphicx}\n";
            output_file << "\\begin{document}\n\n";
        }

        output_file << full_output;

        if (params.output_format == "latex") {
            output_file << "\n\n\\end{document}\n";
        }

        output_file.close();
    }

    // Print statistics
    printf("\n");
    printf("╔════════════════════════════════════╗\n");
    printf("║          OCR Complete!             ║\n");
    printf("╠════════════════════════════════════╣\n");
    printf("║ Pages processed: %-17d ║\n", num_pages);
    printf("║ Time taken:      %-16llds ║\n", (long long) duration.count());
    printf("║ Output size:     %-17zu ║\n", full_output.size());
    printf("╚════════════════════════════════════╝\n");

    // Cleanup
    nougat_preprocess_cleanup(preprocessed_data);
    llama_free(text_ctx);
    llama_free_model(text_model);
    swin_free(vision_ctx);
    llama_backend_free();

    return 0;
}

@@ -0,0 +1,27 @@
set(TARGET nougat-cli)

# Add executable
add_executable(${TARGET} ../nougat-cli.cpp)

# Link with llama library
target_link_libraries(${TARGET} PRIVATE
    llama
    ${CMAKE_THREAD_LIBS_INIT}
)

# Include directories
target_include_directories(${TARGET} PRIVATE
    ${CMAKE_SOURCE_DIR}/include
    ${CMAKE_SOURCE_DIR}/tools
)

# Compile flags
llama_add_compile_flags()

# Set output name
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME "nougat-cli")

# Install target
if(LLAMA_INSTALL)
    install(TARGETS ${TARGET} RUNTIME DESTINATION bin)
endif()
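
# Build sketch (assumption: this directory is wired into the llama.cpp
# tools tree):
#   cmake -B build && cmake --build build --target nougat-cli
#   ./build/bin/nougat-cli -i paper.pdf -o paper.md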