From 0a271d82b46577d955b6c5c8020aab3b4d21057d Mon Sep 17 00:00:00 2001
From: Daniel Bevenius
Date: Fri, 19 Dec 2025 08:43:16 +0100
Subject: [PATCH] model-conversion : add verbose flag in run-org-model.py
 (#18194)

This commit adds a --verbose flag to the run-org-model.py script to
enable detailed debug output, such as the input and output tensors for
each layer. The debug utilities (summarize, debug_hook,
setup_rope_debug) have been moved to utils/common.py.

The motivation for this is that the detailed debug output can be useful
for diagnosing issues with model conversion or execution, but it can
also produce a large amount of output that is not always needed.

The script will also be further cleaned up/refactored in follow-up
commits.
---
 .../scripts/causal/run-org-model.py           | 139 +++---------------
 .../model-conversion/scripts/utils/common.py  | 130 ++++++++++++++++
 2 files changed, 147 insertions(+), 122 deletions(-)

diff --git a/examples/model-conversion/scripts/causal/run-org-model.py b/examples/model-conversion/scripts/causal/run-org-model.py
index 229ea46550..14bb12fe68 100755
--- a/examples/model-conversion/scripts/causal/run-org-model.py
+++ b/examples/model-conversion/scripts/causal/run-org-model.py
@@ -2,135 +2,22 @@
 import argparse
 import os
+import sys
 import importlib
 from pathlib import Path
 
+# Add parent directory to path for imports
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
+
 from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, AutoConfig
 import torch
 import numpy as np
 
-
-### If you want to dump RoPE activations, apply this monkey patch to the model
-### class from Transformers that you are running (replace apertus.modeling_apertus
-### with the proper package and class for your model
-### === START ROPE DEBUG ===
-# from transformers.models.apertus.modeling_apertus import apply_rotary_pos_emb
-
-# orig_rope = apply_rotary_pos_emb
-# torch.set_printoptions(threshold=float('inf'))
-# torch.set_printoptions(precision=6, sci_mode=False)
-
-# def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-#     # log inputs
-#     summarize(q, "RoPE.q_in")
-#     summarize(k, "RoPE.k_in")
-
-#     # call original
-#     q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
-
-#     # log outputs
-#     summarize(q_out, "RoPE.q_out")
-#     summarize(k_out, "RoPE.k_out")
-
-#     return q_out, k_out
-
-# # Patch it
-# import transformers.models.apertus.modeling_apertus as apertus_mod  # noqa: E402
-# apertus_mod.apply_rotary_pos_emb = debug_rope
-### == END ROPE DEBUG ===
-
-
-def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
-    """
-    Print a tensor in llama.cpp debug style.
-
-    Supports:
-    - 2D tensors (seq, hidden)
-    - 3D tensors (batch, seq, hidden)
-    - 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head
-
-    Shows first and last max_vals of each vector per sequence position.
-    """
-    t = tensor.detach().to(torch.float32).cpu()
-
-    # Determine dimensions
-    if t.ndim == 3:
-        _, s, _ = t.shape
-    elif t.ndim == 2:
-        _, s = 1, t.shape[0]
-        t = t.unsqueeze(0)
-    elif t.ndim == 4:
-        _, s, _, _ = t.shape
-    else:
-        print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
-        return
-
-    ten_shape = t.shape
-
-    print(f"ggml_debug: {name} = (f32) ... = {{{ten_shape}}}")
-    print(" [")
-    print(" [")
-
-    # Determine indices for first and last sequences
-    first_indices = list(range(min(s, max_seq)))
-    last_indices = list(range(max(0, s - max_seq), s))
-
-    # Check if there's an overlap between first and last indices or if we're at the edge case of s = 2 * max_seq
-    has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)
-
-    # Combine indices
-    if has_overlap:
-        # If there's overlap, just use the combined unique indices
-        indices = sorted(list(set(first_indices + last_indices)))
-        separator_index = None
-    else:
-        # If no overlap, we'll add a separator between first and last sequences
-        indices = first_indices + last_indices
-        separator_index = len(first_indices)
-
-    for i, si in enumerate(indices):
-        # Add separator if needed
-        if separator_index is not None and i == separator_index:
-            print(" ...")
-
-        # Extract appropriate slice
-        vec = t[0, si]
-        if vec.ndim == 2: # 4D case: flatten heads × dim_per_head
-            flat = vec.flatten().tolist()
-        else: # 2D or 3D case
-            flat = vec.tolist()
-
-        # First and last slices
-        first = flat[:max_vals]
-        last = flat[-max_vals:] if len(flat) >= max_vals else flat
-        first_str = ", ".join(f"{v:12.4f}" for v in first)
-        last_str = ", ".join(f"{v:12.4f}" for v in last)
-
-        print(f" [{first_str}, ..., {last_str}]")
-
-    print(" ],")
-    print(" ]")
-    print(f" sum = {t.sum().item():.6f}\n")
-
-
-def debug_hook(name):
-    def fn(_m, input, output):
-        if isinstance(input, torch.Tensor):
-            summarize(input, name + "_in")
-        elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
-            summarize(input[0], name + "_in")
-        if isinstance(output, torch.Tensor):
-            summarize(output, name + "_out")
-        elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
-            summarize(output[0], name + "_out")
-
-    return fn
-
-
-unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
+from utils.common import debug_hook
 
 parser = argparse.ArgumentParser(description="Process model with specified path")
 parser.add_argument("--model-path", "-m", help="Path to the model")
 parser.add_argument("--prompt-file", "-f", help="Optional prompt file", required=False)
+parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose debug output")
 args = parser.parse_args()
 
 model_path = os.environ.get("MODEL_PATH", args.model_path)
@@ -139,6 +26,12 @@ if model_path is None:
     raise ValueError(
         "Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
     )
 
+### If you want to dump RoPE activations, uncomment the following lines:
+### === START ROPE DEBUG ===
+# from utils.common import setup_rope_debug
+# setup_rope_debug("transformers.models.apertus.modeling_apertus")
+### == END ROPE DEBUG ===
+
 print("Loading model and tokenizer using AutoTokenizer:", model_path)
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
@@ -156,6 +49,7 @@ print("Number of layers: ", config.num_hidden_layers)
 print("BOS token id: ", config.bos_token_id)
 print("EOS token id: ", config.eos_token_id)
 
+unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
 if unreleased_model_name:
     model_name_lower = unreleased_model_name.lower()
     unreleased_module_path = (
@@ -184,9 +78,10 @@ else:
         model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config
     )
 
-for name, module in model.named_modules():
-    if len(list(module.children())) == 0: # only leaf modules
-        module.register_forward_hook(debug_hook(name))
+if args.verbose:
+    for name, module in model.named_modules():
+        if len(list(module.children())) == 0: # only leaf modules
+            module.register_forward_hook(debug_hook(name))
 
 model_name = os.path.basename(model_path)
 # Printing the Model class to allow for easier debugging. This can be useful
diff --git a/examples/model-conversion/scripts/utils/common.py b/examples/model-conversion/scripts/utils/common.py
index 945f9a1a1d..7595d0410e 100644
--- a/examples/model-conversion/scripts/utils/common.py
+++ b/examples/model-conversion/scripts/utils/common.py
@@ -2,6 +2,8 @@
 import os
 import sys
 
+import torch
+
 
 def get_model_name_from_env_path(env_path_name):
     model_path = os.getenv(env_path_name)
@@ -18,3 +20,131 @@
         name = name[:-5]
 
     return name
+
+
+def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
+    """
+    Print a tensor in llama.cpp debug style.
+
+    Supports:
+    - 2D tensors (seq, hidden)
+    - 3D tensors (batch, seq, hidden)
+    - 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head
+
+    Shows first and last max_vals of each vector per sequence position.
+    """
+    t = tensor.detach().to(torch.float32).cpu()
+
+    # Determine dimensions
+    if t.ndim == 3:
+        _, s, _ = t.shape
+    elif t.ndim == 2:
+        _, s = 1, t.shape[0]
+        t = t.unsqueeze(0)
+    elif t.ndim == 4:
+        _, s, _, _ = t.shape
+    else:
+        print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
+        return
+
+    ten_shape = t.shape
+
+    print(f"ggml_debug: {name} = (f32) ... = {{{ten_shape}}}")
+    print(" [")
+    print(" [")
+
+    # Determine indices for first and last sequences
+    first_indices = list(range(min(s, max_seq)))
+    last_indices = list(range(max(0, s - max_seq), s))
+
+    # Check if there's an overlap between first and last indices or if we're at the edge case of s = 2 * max_seq
+    has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)
+
+    # Combine indices
+    if has_overlap:
+        # If there's overlap, just use the combined unique indices
+        indices = sorted(list(set(first_indices + last_indices)))
+        separator_index = None
+    else:
+        # If no overlap, we'll add a separator between first and last sequences
+        indices = first_indices + last_indices
+        separator_index = len(first_indices)
+
+    for i, si in enumerate(indices):
+        # Add separator if needed
+        if separator_index is not None and i == separator_index:
+            print(" ...")
+
+        # Extract appropriate slice
+        vec = t[0, si]
+        if vec.ndim == 2: # 4D case: flatten heads × dim_per_head
+            flat = vec.flatten().tolist()
+        else: # 2D or 3D case
+            flat = vec.tolist()
+
+        # First and last slices
+        first = flat[:max_vals]
+        last = flat[-max_vals:] if len(flat) >= max_vals else flat
+        first_str = ", ".join(f"{v:12.4f}" for v in first)
+        last_str = ", ".join(f"{v:12.4f}" for v in last)
+
+        print(f" [{first_str}, ..., {last_str}]")
+
+    print(" ],")
+    print(" ]")
+    print(f" sum = {t.sum().item():.6f}\n")
+
+
+def debug_hook(name):
+    def fn(_m, input, output):
+        if isinstance(input, torch.Tensor):
+            summarize(input, name + "_in")
+        elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
+            summarize(input[0], name + "_in")
+        if isinstance(output, torch.Tensor):
+            summarize(output, name + "_out")
+        elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
+            summarize(output[0], name + "_out")
+
+    return fn
+
+
+def setup_rope_debug(model_module_path: str, function_name: str = "apply_rotary_pos_emb"):
+    """
+    Apply monkey patch to dump RoPE activations for debugging.
+
+    Args:
+        model_module_path: Path to the model module (e.g., "transformers.models.apertus.modeling_apertus")
+        function_name: Name of the RoPE function to patch (default: "apply_rotary_pos_emb")
+
+    Example:
+        from utils.common import setup_rope_debug
+        setup_rope_debug("transformers.models.apertus.modeling_apertus")
+    """
+    import importlib
+
+    # Import the module and get the original function
+    module = importlib.import_module(model_module_path)
+    orig_rope = getattr(module, function_name)
+
+    # Set torch print options for better debugging
+    torch.set_printoptions(threshold=float('inf'))
+    torch.set_printoptions(precision=6, sci_mode=False)
+
+    def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+        # log inputs
+        summarize(q, "RoPE.q_in")
+        summarize(k, "RoPE.k_in")
+
+        # call original
+        q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
+
+        # log outputs
+        summarize(q_out, "RoPE.q_out")
+        summarize(k_out, "RoPE.k_out")
+
+        return q_out, k_out
+
+    # Patch it
+    setattr(module, function_name, debug_rope)
+    print(f"RoPE debug patching applied to {model_module_path}.{function_name}")
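
Usage sketch (not part of the commit; the model path below is a placeholder):

    # Without the flag the script runs as before, with no per-layer dumps:
    python examples/model-conversion/scripts/causal/run-org-model.py \
        --model-path /path/to/model

    # With -v/--verbose, debug_hook is registered on every leaf module and
    # each module's input/output tensors are printed in llama.cpp debug style:
    python examples/model-conversion/scripts/causal/run-org-model.py \
        --model-path /path/to/model --verbose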