From 717865b62bc8d72ce2823b54f6db3e9da36911c3 Mon Sep 17 00:00:00 2001
From: Piotr Wilkin
Date: Sun, 14 Dec 2025 18:24:49 +0100
Subject: [PATCH] Extend run-org-model.py, add (a) batching (b) loading prompt from file (c) multimodal capacity

---
 .../scripts/causal/run-org-model.py | 41 +++++++++++++++----
 1 file changed, 32 insertions(+), 9 deletions(-)

diff --git a/examples/model-conversion/scripts/causal/run-org-model.py b/examples/model-conversion/scripts/causal/run-org-model.py
index da1132c003..229ea46550 100755
--- a/examples/model-conversion/scripts/causal/run-org-model.py
+++ b/examples/model-conversion/scripts/causal/run-org-model.py
@@ -5,7 +5,7 @@ import os
 import importlib
 from pathlib import Path
 
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, AutoConfig
 import torch
 import numpy as np
 
@@ -116,11 +116,11 @@ def debug_hook(name):
     def fn(_m, input, output):
         if isinstance(input, torch.Tensor):
             summarize(input, name + "_in")
-        elif isinstance(input, (tuple, list)) and isinstance(input[0], torch.Tensor):
+        elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
             summarize(input[0], name + "_in")
         if isinstance(output, torch.Tensor):
             summarize(output, name + "_out")
-        elif isinstance(output, (tuple, list)) and isinstance(output[0], torch.Tensor):
+        elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
             summarize(output[0], name + "_out")
     return fn
 
@@ -130,6 +130,7 @@ unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
 
 parser = argparse.ArgumentParser(description="Process model with specified path")
 parser.add_argument("--model-path", "-m", help="Path to the model")
+parser.add_argument("--prompt-file", "-f", help="Optional prompt file", required=False)
 args = parser.parse_args()
 
 model_path = os.environ.get("MODEL_PATH", args.model_path)
@@ -142,8 +143,13 @@ if model_path is None:
 print("Loading model and tokenizer using AutoTokenizer:", model_path)
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+multimodal = False
+full_config = config
 
 print("Model type: ", config.model_type)
+if "vocab_size" not in config and "text_config" in config:
+    config = config.text_config
+    multimodal = True
 print("Vocab size: ", config.vocab_size)
 print("Hidden size: ", config.hidden_size)
 print("Number of layers: ", config.num_hidden_layers)
@@ -169,9 +175,14 @@ if unreleased_model_name:
         print(f"Failed to import or load model: {e}")
         exit(1)
 else:
-    model = AutoModelForCausalLM.from_pretrained(
-        model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config
-    )
+    if multimodal:
+        model = AutoModelForImageTextToText.from_pretrained(
+            model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=full_config
+        )
+    else:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config
+        )
 
 for name, module in model.named_modules():
     if len(list(module.children())) == 0:  # only leaf modules
@@ -185,7 +196,10 @@ model_name = os.path.basename(model_path)
 print(f"Model class: {model.__class__.__name__}")
 
 device = next(model.parameters()).device
-if os.getenv("MODEL_TESTING_PROMPT"):
+if args.prompt_file:
+    with open(args.prompt_file, encoding='utf-8') as f:
+        prompt = f.read()
+elif os.getenv("MODEL_TESTING_PROMPT"):
     prompt = os.getenv("MODEL_TESTING_PROMPT")
 else:
     prompt = "Hello, my name is"
@@ -195,9 +209,18 @@ print(f"Input tokens: {input_ids}")
 print(f"Input text: {repr(prompt)}")
 print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
 
+batch_size = 512
+
 with torch.no_grad():
-    outputs = model(input_ids.to(model.device))
-    logits = outputs.logits
+    past = None
+    outputs = None
+    for i in range(0, input_ids.size(1), batch_size):
+        print(f"Processing chunk with tokens {i} to {i + batch_size}")
+        chunk = input_ids[:, i:i + batch_size]
+        outputs = model(chunk.to(model.device), past_key_values=past, use_cache=True)
+        past = outputs.past_key_values
+
+    logits = outputs.logits  # type: ignore
 
 # Extract logits for the last token (next token prediction)
 last_logits = logits[0, -1, :].float().cpu().numpy()
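
Note (not part of the patch): the following is a minimal standalone sketch of the chunked-prefill pattern the new batching loop relies on, feeding the prompt in fixed-size chunks while carrying past_key_values forward so the final logits match a single full pass. The model id "gpt2" and the tiny chunk size are illustrative placeholders; the script itself uses the user-supplied model and batch_size = 512.

# Standalone sketch, assuming a Hugging Face causal LM ("gpt2" here) is available.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "gpt2"  # placeholder model for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

prompt = "Hello, my name is"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

chunk_size = 4  # tiny on purpose; the patch uses batch_size = 512
past = None
outputs = None
with torch.no_grad():
    for i in range(0, input_ids.size(1), chunk_size):
        chunk = input_ids[:, i:i + chunk_size]
        # Only the new tokens are fed; past_key_values supplies the earlier context.
        outputs = model(chunk, past_key_values=past, use_cache=True)
        past = outputs.past_key_values

# Logits of the last prompt token give the next-token prediction, as in the script.
next_token_id = int(outputs.logits[0, -1].argmax())
print("next token:", repr(tokenizer.decode([next_token_id])))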