commit c53f82da8f
Piotr Wilkin (ilintar), 2025-12-17 05:51:06 +02:00, committed by GitHub
1 changed file with 32 additions and 9 deletions


@@ -5,7 +5,7 @@ import os
 import importlib
 from pathlib import Path
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, AutoConfig
 import torch
 import numpy as np
@@ -116,11 +116,11 @@ def debug_hook(name):
     def fn(_m, input, output):
         if isinstance(input, torch.Tensor):
             summarize(input, name + "_in")
-        elif isinstance(input, (tuple, list)) and isinstance(input[0], torch.Tensor):
+        elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
             summarize(input[0], name + "_in")
         if isinstance(output, torch.Tensor):
             summarize(output, name + "_out")
-        elif isinstance(output, (tuple, list)) and isinstance(output[0], torch.Tensor):
+        elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
             summarize(output[0], name + "_out")
     return fn
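
PyTorch forward hooks receive (module, input, output), where input is the tuple of positional arguments and can be empty when a module is called with keyword arguments only; that is the case the added len(...) > 0 checks guard against. A minimal, self-contained sketch of the pattern (toy model and print-based summarize stand-in, names illustrative):

    import torch
    import torch.nn as nn

    def debug_hook(name):
        def fn(_m, input, output):
            # input is a tuple of positional args; it can be empty, hence the length check
            if isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
                print(name + "_in", tuple(input[0].shape))
            if isinstance(output, torch.Tensor):
                print(name + "_out", tuple(output.shape))
        return fn

    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
    for name, module in model.named_modules():
        module.register_forward_hook(debug_hook(name))
    model(torch.randn(1, 4))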
@@ -130,6 +130,7 @@ unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
 parser = argparse.ArgumentParser(description="Process model with specified path")
 parser.add_argument("--model-path", "-m", help="Path to the model")
+parser.add_argument("--prompt-file", "-f", help="Optional prompt file", required=False)
 args = parser.parse_args()
 model_path = os.environ.get("MODEL_PATH", args.model_path)
@@ -142,8 +143,13 @@ if model_path is None:
 print("Loading model and tokenizer using AutoTokenizer:", model_path)
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+multimodal = False
+full_config = config
 print("Model type: ", config.model_type)
+if "vocab_size" not in config and "text_config" in config:
+    config = config.text_config
+    multimodal = True
 print("Vocab size: ", config.vocab_size)
 print("Hidden size: ", config.hidden_size)
 print("Number of layers: ", config.num_hidden_layers)
@@ -168,6 +174,11 @@ if unreleased_model_name:
     except (ImportError, AttributeError) as e:
         print(f"Failed to import or load model: {e}")
         exit(1)
+else:
+    if multimodal:
+        model = AutoModelForImageTextToText.from_pretrained(
+            model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=full_config
+        )
     else:
         model = AutoModelForCausalLM.from_pretrained(
             model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config
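
AutoModelForImageTextToText resolves to the full vision-language class for such checkpoints, and the unmodified full_config is passed here because the narrowed text_config alone would not describe the vision tower. A minimal sketch of the class-selection logic in isolation (path and the hasattr-based test are assumptions, not the script's exact check):

    from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForImageTextToText

    model_path = "org/some-model"  # illustrative
    cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    cls = AutoModelForImageTextToText if hasattr(cfg, "text_config") else AutoModelForCausalLM
    model = cls.from_pretrained(model_path, device_map="auto", trust_remote_code=True)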
@@ -185,7 +196,10 @@ model_name = os.path.basename(model_path)
 print(f"Model class: {model.__class__.__name__}")
 device = next(model.parameters()).device
-if os.getenv("MODEL_TESTING_PROMPT"):
+if args.prompt_file:
+    with open(args.prompt_file, encoding='utf-8') as f:
+        prompt = f.read()
+elif os.getenv("MODEL_TESTING_PROMPT"):
     prompt = os.getenv("MODEL_TESTING_PROMPT")
 else:
     prompt = "Hello, my name is"
@@ -195,9 +209,18 @@ print(f"Input tokens: {input_ids}")
 print(f"Input text: {repr(prompt)}")
 print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
+batch_size = 512
+
 with torch.no_grad():
-    outputs = model(input_ids.to(model.device))
-    logits = outputs.logits
+    past = None
+    outputs = None
+    for i in range(0, input_ids.size(1), batch_size):
+        print(f"Processing chunk with tokens {i} to {i + batch_size}")
+        chunk = input_ids[:, i:i + batch_size]
+        outputs = model(chunk.to(model.device), past_key_values=past, use_cache=True)
+        past = outputs.past_key_values
+
+logits = outputs.logits  # type: ignore
 # Extract logits for the last token (next token prediction)
 last_logits = logits[0, -1, :].float().cpu().numpy()
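
The loop implements a chunked prefill: the prompt is fed through the model in windows of batch_size tokens while the returned past_key_values cache is threaded forward, so each chunk attends to everything processed before it and peak activation memory scales with the chunk size rather than the full prompt length. The last chunk's final-position logits are exactly what the next-token extraction above needs. A self-contained sketch of the pattern (tiny model and chunk size, chosen only to exercise the loop):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "gpt2"  # illustrative; any causal LM works
    tok = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id).eval()

    input_ids = tok("Hello, my name is", return_tensors="pt").input_ids
    batch_size = 4  # deliberately tiny so the prompt spans several chunks
    with torch.no_grad():
        past, outputs = None, None
        for i in range(0, input_ids.size(1), batch_size):
            chunk = input_ids[:, i:i + batch_size]
            # use_cache=True returns the KV cache; passing it back in lets the
            # next chunk attend to all previously processed tokens
            outputs = model(chunk, past_key_values=past, use_cache=True)
            past = outputs.past_key_values

    # Greedy next-token prediction from the last position of the last chunk
    next_token = int(outputs.logits[0, -1].argmax())
    print(repr(tok.decode([next_token])))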