#!/usr/bin/env python3 import argparse import json import os import re import sys from pathlib import Path from typing import Optional from safetensors import safe_open MODEL_SAFETENSORS_FILE = "model.safetensors" MODEL_SAFETENSORS_INDEX = "model.safetensors.index.json" def get_weight_map(model_path: Path) -> Optional[dict[str, str]]: index_file = model_path / MODEL_SAFETENSORS_INDEX if index_file.exists(): with open(index_file, 'r') as f: index = json.load(f) return index.get("weight_map", {}) return None def get_all_tensor_names(model_path: Path) -> list[str]: weight_map = get_weight_map(model_path) if weight_map is not None: return list(weight_map.keys()) single_file = model_path / MODEL_SAFETENSORS_FILE if single_file.exists(): try: with safe_open(single_file, framework="pt", device="cpu") as f: return list(f.keys()) except Exception as e: print(f"Error reading {single_file}: {e}") sys.exit(1) print(f"Error: No safetensors files found in {model_path}") sys.exit(1) def find_tensor_file(model_path: Path, tensor_name: str) -> Optional[str]: weight_map = get_weight_map(model_path) if weight_map is not None: return weight_map.get(tensor_name) single_file = model_path / MODEL_SAFETENSORS_FILE if single_file.exists(): return single_file.name return None def normalize_tensor_name(tensor_name: str) -> str: normalized = re.sub(r'\.\d+\.', '.#.', tensor_name) normalized = re.sub(r'\.\d+$', '.#', normalized) return normalized def list_all_tensors(model_path: Path, unique: bool = False): tensor_names = get_all_tensor_names(model_path) if unique: seen = set() for tensor_name in sorted(tensor_names): normalized = normalize_tensor_name(tensor_name) if normalized not in seen: seen.add(normalized) print(normalized) else: for tensor_name in sorted(tensor_names): print(tensor_name) def print_tensor_info(model_path: Path, tensor_name: str): tensor_file = find_tensor_file(model_path, tensor_name) if tensor_file is None: print(f"Error: Could not find tensor '{tensor_name}' in model index") print(f"Model path: {model_path}") sys.exit(1) file_path = model_path / tensor_file try: with safe_open(file_path, framework="pt", device="cpu") as f: if tensor_name in f.keys(): tensor_slice = f.get_slice(tensor_name) shape = tensor_slice.get_shape() print(f"Tensor: {tensor_name}") print(f"File: {tensor_file}") print(f"Shape: {shape}") else: print(f"Error: Tensor '{tensor_name}' not found in {tensor_file}") sys.exit(1) except FileNotFoundError: print(f"Error: The file '{file_path}' was not found.") sys.exit(1) except Exception as e: print(f"An error occurred: {e}") sys.exit(1) def main(): parser = argparse.ArgumentParser( description="Print tensor information from a safetensors model" ) parser.add_argument( "tensor_name", nargs="?", # optional (if --list is used for example) help="Name of the tensor to inspect" ) parser.add_argument( "-m", "--model-path", type=Path, help="Path to the model directory (default: MODEL_PATH environment variable)" ) parser.add_argument( "-l", "--list", action="store_true", help="List unique tensor patterns in the model (layer numbers replaced with #)" ) args = parser.parse_args() model_path = args.model_path if model_path is None: model_path_str = os.environ.get("MODEL_PATH") if model_path_str is None: print("Error: --model-path not provided and MODEL_PATH environment variable not set") sys.exit(1) model_path = Path(model_path_str) if not model_path.exists(): print(f"Error: Model path does not exist: {model_path}") sys.exit(1) if not model_path.is_dir(): print(f"Error: Model path is not a directory: {model_path}") sys.exit(1) if args.list: list_all_tensors(model_path, unique=True) else: if args.tensor_name is None: print("Error: tensor_name is required when not using --list") sys.exit(1) print_tensor_info(model_path, args.tensor_name) if __name__ == "__main__": main()