llama.cpp/tools/nougat-cli.cpp

539 lines
19 KiB
C++

#include "common.h"
#include "ggml.h"
#include "llama.h"
#include "mtmd/swin.h"
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include <thread>
#include <chrono>
// External preprocessing function
extern "C" bool nougat_preprocess_pipeline(
const char* input_path,
float** output_data,
int* output_width,
int* output_height,
int* num_pages);
extern "C" void nougat_preprocess_cleanup(float* data);
// CLI arguments structure
struct nougat_params {
std::string input_path = "";
std::string output_path = "";
std::string vision_model = "models/nougat-vision-swin.gguf";
std::string text_model = "models/nougat-text-mbart.gguf";
std::string projector_model = "models/nougat-projector.gguf";
// Processing options
bool batch_mode = false;
int batch_size = 1;
int n_threads = 4;
int n_gpu_layers = 0;
// Output options
std::string output_format = "markdown"; // markdown, latex, plain
bool verbose = false;
bool save_intermediate = false;
// Performance options
bool use_mmap = true;
bool use_flash_attn = false;
int context_size = 2048;
// Document-specific options
bool deskew = true;
bool denoise = true;
bool detect_tables = true;
bool detect_math = true;
int max_pages = -1; // -1 for all pages
};
static void print_usage(const char* prog_name) {
fprintf(stdout, "\n");
fprintf(stdout, "Nougat OCR - Neural Optical Understanding for Academic Documents\n");
fprintf(stdout, "\n");
fprintf(stdout, "Usage: %s [options] -i input_file -o output_file\n", prog_name);
fprintf(stdout, "\n");
fprintf(stdout, "Options:\n");
fprintf(stdout, " -i, --input FILE Input document (PDF, PNG, JPG)\n");
fprintf(stdout, " -o, --output FILE Output file path\n");
fprintf(stdout, " --vision-model FILE Path to vision model GGUF (default: models/nougat-vision-swin.gguf)\n");
fprintf(stdout, " --text-model FILE Path to text model GGUF (default: models/nougat-text-mbart.gguf)\n");
fprintf(stdout, " --projector FILE Path to projector model GGUF (default: models/nougat-projector.gguf)\n");
fprintf(stdout, "\n");
fprintf(stdout, " Processing Options:\n");
fprintf(stdout, " -t, --threads N Number of threads (default: 4)\n");
fprintf(stdout, " -ngl, --n-gpu-layers N Number of layers to offload to GPU (default: 0)\n");
fprintf(stdout, " -b, --batch-size N Batch size for processing (default: 1)\n");
fprintf(stdout, " -c, --context-size N Context size (default: 2048)\n");
fprintf(stdout, " --max-pages N Maximum pages to process (default: all)\n");
fprintf(stdout, "\n");
fprintf(stdout, " Output Options:\n");
fprintf(stdout, " -f, --format FORMAT Output format: markdown, latex, plain (default: markdown)\n");
fprintf(stdout, " -v, --verbose Verbose output\n");
fprintf(stdout, " --save-intermediate Save intermediate processing results\n");
fprintf(stdout, "\n");
fprintf(stdout, " Document Processing:\n");
fprintf(stdout, " --no-deskew Disable automatic deskewing\n");
fprintf(stdout, " --no-denoise Disable denoising\n");
fprintf(stdout, " --no-tables Disable table detection\n");
fprintf(stdout, " --no-math Disable math formula detection\n");
fprintf(stdout, "\n");
fprintf(stdout, " Performance Options:\n");
fprintf(stdout, " --no-mmap Disable memory mapping\n");
fprintf(stdout, " --flash-attn Use flash attention\n");
fprintf(stdout, "\n");
fprintf(stdout, "Examples:\n");
fprintf(stdout, " # Basic OCR of a PDF document\n");
fprintf(stdout, " %s -i paper.pdf -o paper.md\n", prog_name);
fprintf(stdout, "\n");
fprintf(stdout, " # Process with GPU acceleration\n");
fprintf(stdout, " %s -i scan.png -o text.md -ngl 32 -t 8\n", prog_name);
fprintf(stdout, "\n");
fprintf(stdout, " # LaTeX output with math detection\n");
fprintf(stdout, " %s -i math_paper.pdf -o paper.tex -f latex --detect-math\n", prog_name);
fprintf(stdout, "\n");
}
static bool parse_args(int argc, char** argv, nougat_params& params) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-i" || arg == "--input") {
if (++i >= argc) {
fprintf(stderr, "Error: Missing argument for %s\n", arg.c_str());
return false;
}
params.input_path = argv[i];
}
else if (arg == "-o" || arg == "--output") {
if (++i >= argc) {
fprintf(stderr, "Error: Missing argument for %s\n", arg.c_str());
return false;
}
params.output_path = argv[i];
}
else if (arg == "--vision-model") {
if (++i >= argc) {
fprintf(stderr, "Error: Missing argument for %s\n", arg.c_str());
return false;
}
params.vision_model = argv[i];
}
else if (arg == "--text-model") {
if (++i >= argc) {
fprintf(stderr, "Error: Missing argument for %s\n", arg.c_str());
return false;
}
params.text_model = argv[i];
}
else if (arg == "--projector") {
if (++i >= argc) {
fprintf(stderr, "Error: Missing argument for %s\n", arg.c_str());
return false;
}
params.projector_model = argv[i];
}
else if (arg == "-t" || arg == "--threads") {
if (++i >= argc) {
fprintf(stderr, "Error: Missing argument for %s\n", arg.c_str());
return false;
}
params.n_threads = std::stoi(argv[i]);
}
else if (arg == "-ngl" || arg == "--n-gpu-layers") {
if (++i >= argc) {
fprintf(stderr, "Error: Missing argument for %s\n", arg.c_str());
return false;
}
params.n_gpu_layers = std::stoi(argv[i]);
}
else if (arg == "-b" || arg == "--batch-size") {
if (++i >= argc) {
fprintf(stderr, "Error: Missing argument for %s\n", arg.c_str());
return false;
}
params.batch_size = std::stoi(argv[i]);
}
else if (arg == "-c" || arg == "--context-size") {
if (++i >= argc) {
fprintf(stderr, "Error: Missing argument for %s\n", arg.c_str());
return false;
}
params.context_size = std::stoi(argv[i]);
}
else if (arg == "--max-pages") {
if (++i >= argc) {
fprintf(stderr, "Error: Missing argument for %s\n", arg.c_str());
return false;
}
params.max_pages = std::stoi(argv[i]);
}
else if (arg == "-f" || arg == "--format") {
if (++i >= argc) {
fprintf(stderr, "Error: Missing argument for %s\n", arg.c_str());
return false;
}
params.output_format = argv[i];
}
else if (arg == "-v" || arg == "--verbose") {
params.verbose = true;
}
else if (arg == "--save-intermediate") {
params.save_intermediate = true;
}
else if (arg == "--no-deskew") {
params.deskew = false;
}
else if (arg == "--no-denoise") {
params.denoise = false;
}
else if (arg == "--no-tables") {
params.detect_tables = false;
}
else if (arg == "--no-math") {
params.detect_math = false;
}
else if (arg == "--no-mmap") {
params.use_mmap = false;
}
else if (arg == "--flash-attn") {
params.use_flash_attn = true;
}
else if (arg == "-h" || arg == "--help") {
print_usage(argv[0]);
exit(0);
}
else {
fprintf(stderr, "Error: Unknown argument '%s'\n", arg.c_str());
return false;
}
}
// Validate required arguments
if (params.input_path.empty()) {
fprintf(stderr, "Error: Input file is required\n");
return false;
}
if (params.output_path.empty()) {
// Generate default output path
size_t dot_pos = params.input_path.find_last_of(".");
params.output_path = params.input_path.substr(0, dot_pos);
if (params.output_format == "markdown") {
params.output_path += ".md";
} else if (params.output_format == "latex") {
params.output_path += ".tex";
} else {
params.output_path += ".txt";
}
}
return true;
}
// Process a single page through the Nougat pipeline
static std::string process_page(
struct swin_ctx* vision_ctx,
struct llama_model* text_model,
struct llama_context* text_ctx,
const float* image_data,
int width,
int height,
const nougat_params& params) {
// Step 1: Encode image with Swin Transformer
if (params.verbose) {
printf("Encoding image with Swin Transformer...\n");
}
// Create image batch
swin_image_f32 img = {
width,
height,
3,
std::vector<float>(image_data, image_data + width * height * 3)
};
swin_image_batch imgs = {1, &img};
// Encode image
std::vector<float> vision_embeddings(2048); // Adjust size based on model
if (!swin_image_batch_encode(vision_ctx, params.n_threads, &imgs, vision_embeddings.data())) {
fprintf(stderr, "Failed to encode image\n");
return "";
}
// Step 2: Pass embeddings through projector
// This would map vision embeddings to text embedding space
// Step 3: Generate text with mBART decoder
if (params.verbose) {
printf("Generating text with mBART decoder...\n");
}
// Create batch for text generation
llama_batch batch = llama_batch_init(params.context_size, 0, 1);
// Set up cross-attention with vision embeddings
// This requires the decoder to attend to encoder outputs
// Start with BOS token
llama_token bos_token = llama_token_get_bos(text_model);
batch.token[0] = bos_token;
batch.pos[0] = 0;
batch.n_seq_id[0] = 1;
batch.seq_id[0][0] = 0;
batch.n_tokens = 1;
// Decode initial token
if (llama_decode(text_ctx, batch) != 0) {
fprintf(stderr, "Failed to decode\n");
llama_batch_free(batch);
return "";
}
// Generate text autoregressively
std::vector<llama_token> generated_tokens;
generated_tokens.push_back(bos_token);
llama_token eos_token = llama_token_get_eos(text_model);
int max_tokens = params.context_size;
for (int i = 1; i < max_tokens; i++) {
// Get logits from last position
float* logits = llama_get_logits_ith(text_ctx, batch.n_tokens - 1);
// Sample next token
int n_vocab = llama_n_vocab(text_model);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
// Sample with top-k and top-p
int top_k = 40;
float top_p = 0.9f;
float temp = 0.8f;
llama_sample_top_k(text_ctx, &candidates_p, top_k, 1);
llama_sample_top_p(text_ctx, &candidates_p, top_p, 1);
llama_sample_temp(text_ctx, &candidates_p, temp);
llama_token new_token = llama_sample_token(text_ctx, &candidates_p);
// Check for EOS
if (new_token == eos_token) {
break;
}
generated_tokens.push_back(new_token);
// Add to batch for next iteration
batch.token[0] = new_token;
batch.pos[0] = i;
batch.n_tokens = 1;
if (llama_decode(text_ctx, batch) != 0) {
fprintf(stderr, "Failed to continue decoding\n");
break;
}
}
llama_batch_free(batch);
// Convert tokens to text
std::string result;
for (auto token : generated_tokens) {
std::string piece = llama_token_to_piece(text_ctx, token, true);
result += piece;
}
return result;
}
int main(int argc, char** argv) {
nougat_params params;
// Parse command line arguments
if (!parse_args(argc, argv, params)) {
print_usage(argv[0]);
return 1;
}
// Print banner
printf("\n");
printf("╔═══════════════════════════════════════════════════════╗\n");
printf("║ Nougat OCR - Document Understanding ║\n");
printf("║ Powered by Swin Transformer + mBART ║\n");
printf("╚═══════════════════════════════════════════════════════╝\n");
printf("\n");
printf("Input: %s\n", params.input_path.c_str());
printf("Output: %s\n", params.output_path.c_str());
printf("Format: %s\n", params.output_format.c_str());
printf("\n");
// Initialize backend
llama_backend_init();
// Load vision model (Swin Transformer)
printf("Loading vision model from %s...\n", params.vision_model.c_str());
struct swin_ctx* vision_ctx = swin_model_load(params.vision_model, params.verbose ? 2 : 1);
if (!vision_ctx) {
fprintf(stderr, "Failed to load vision model\n");
return 1;
}
// Load text model (mBART)
printf("Loading text model from %s...\n", params.text_model.c_str());
llama_model_params model_params = llama_model_default_params();
model_params.n_gpu_layers = params.n_gpu_layers;
model_params.use_mmap = params.use_mmap;
struct llama_model* text_model = llama_load_model_from_file(
params.text_model.c_str(), model_params);
if (!text_model) {
fprintf(stderr, "Failed to load text model\n");
swin_free(vision_ctx);
return 1;
}
// Create text generation context
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = params.context_size;
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads;
ctx_params.flash_attn = params.use_flash_attn;
struct llama_context* text_ctx = llama_new_context_with_model(text_model, ctx_params);
if (!text_ctx) {
fprintf(stderr, "Failed to create text context\n");
llama_free_model(text_model);
swin_free(vision_ctx);
return 1;
}
// Preprocess document
printf("Preprocessing document...\n");
float* preprocessed_data = nullptr;
int width, height, num_pages;
if (!nougat_preprocess_pipeline(
params.input_path.c_str(),
&preprocessed_data,
&width, &height, &num_pages)) {
fprintf(stderr, "Failed to preprocess document\n");
llama_free(text_ctx);
llama_free_model(text_model);
swin_free(vision_ctx);
return 1;
}
printf("Document info: %d pages, %dx%d pixels\n", num_pages, width, height);
// Limit pages if requested
if (params.max_pages > 0 && num_pages > params.max_pages) {
num_pages = params.max_pages;
printf("Processing first %d pages only\n", num_pages);
}
// Process each page
std::string full_output;
auto start_time = std::chrono::high_resolution_clock::now();
for (int page = 0; page < num_pages; page++) {
printf("\nProcessing page %d/%d...\n", page + 1, num_pages);
float* page_data = preprocessed_data + (page * width * height * 3);
std::string page_text = process_page(
vision_ctx, text_model, text_ctx,
page_data, width, height, params);
if (page_text.empty()) {
fprintf(stderr, "Warning: Failed to process page %d\n", page + 1);
continue;
}
// Add page separator for multi-page documents
if (page > 0) {
if (params.output_format == "markdown") {
full_output += "\n\n---\n\n";
} else if (params.output_format == "latex") {
full_output += "\n\\newpage\n\n";
} else {
full_output += "\n\n[Page " + std::to_string(page + 1) + "]\n\n";
}
}
full_output += page_text;
// Save intermediate results if requested
if (params.save_intermediate) {
std::string intermediate_file = params.output_path + ".page" +
std::to_string(page + 1) + ".tmp";
std::ofstream tmp_out(intermediate_file);
tmp_out << page_text;
tmp_out.close();
}
}
auto end_time = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::seconds>(end_time - start_time);
// Save final output
printf("\nSaving output to %s...\n", params.output_path.c_str());
std::ofstream output_file(params.output_path);
if (!output_file) {
fprintf(stderr, "Failed to open output file\n");
} else {
// Add format-specific headers/footers
if (params.output_format == "latex") {
output_file << "\\documentclass{article}\n";
output_file << "\\usepackage{amsmath}\n";
output_file << "\\usepackage{graphicx}\n";
output_file << "\\begin{document}\n\n";
}
output_file << full_output;
if (params.output_format == "latex") {
output_file << "\n\n\\end{document}\n";
}
output_file.close();
}
// Print statistics
printf("\n");
printf("╔════════════════════════════════════╗\n");
printf("║ OCR Complete! ║\n");
printf("╠════════════════════════════════════╣\n");
printf("║ Pages processed: %-17d ║\n", num_pages);
printf("║ Time taken: %-17lds║\n", duration.count());
printf("║ Output size: %-17zd ║\n", full_output.size());
printf("╚════════════════════════════════════╝\n");
// Cleanup
nougat_preprocess_cleanup(preprocessed_data);
llama_free(text_ctx);
llama_free_model(text_model);
swin_free(vision_ctx);
llama_backend_free();
return 0;
}