#!/usr/bin/env python3
# generated by Claude
"""
Script to inspect SafeTensors model files and print tensor information.

Usage:
    inspect_safetensors.py [--model-dir DIR] [--summary]

If DIR contains ``model.safetensors.index.json`` (a sharded checkpoint), only
the shard files referenced by its ``weight_map`` are inspected; otherwise every
``*.safetensors`` file directly inside DIR is inspected.
"""

import argparse
import json
import os
import sys
from collections import Counter
from pathlib import Path

from safetensors import safe_open

INDEX_FILENAME = "model.safetensors.index.json"


def _load_index(model_dir):
    """Return the parsed index JSON for *model_dir*, or None if there is no index file."""
    index_file = Path(model_dir) / INDEX_FILENAME
    if not index_file.exists():
        return None
    with open(index_file, "r", encoding="utf-8") as f:
        return json.load(f)


def _find_safetensor_files(model_dir, index_data=None):
    """Return the sorted list of SafeTensors file names to inspect.

    When *index_data* is given, the distinct file names in its "weight_map"
    values are used; otherwise every ``*.safetensors`` entry of *model_dir*.
    """
    if index_data is not None:
        names = set(index_data.get("weight_map", {}).values())
    else:
        names = [f for f in os.listdir(model_dir) if f.endswith(".safetensors")]
    # Sort for consistent, reproducible output.
    return sorted(names)


def _warn_missing(filepath):
    """Report a file that was expected (e.g. listed in the index) but is absent."""
    print(f"Warning: file not found, skipping: {filepath}", file=sys.stderr)


def _iter_tensor_metadata(filepath):
    """Yield (tensor_name, shape, dtype) for every tensor in *filepath*, sorted by name.

    Only header metadata is read via ``get_slice``; tensor data is not loaded.
    """
    # Use PyTorch framework for better dtype support.
    with safe_open(str(filepath), framework="pt") as f:
        for tensor_name in sorted(f.keys()):
            tensor_slice = f.get_slice(tensor_name)
            yield tensor_name, tensor_slice.get_shape(), tensor_slice.get_dtype()


def _num_elements(shape):
    """Return the number of elements of a tensor with the given *shape* (1 for a scalar)."""
    count = 1
    for dim in shape:
        count *= dim
    return count


def inspect_safetensors_model(model_dir="."):
    """Print name, shape and dtype of every tensor in the model directory.

    Args:
        model_dir: Directory holding the ``.safetensors`` files (and optionally
            ``model.safetensors.index.json``).
    """
    index_data = _load_index(model_dir)
    if index_data is not None:
        print("=== Model Structure ===")
        # In sharded-checkpoint indexes, metadata.total_size is the total size
        # of all tensors in bytes, not a parameter count.
        metadata = index_data.get("metadata") or {}
        print(f"Total size (bytes): {metadata.get('total_size', 'Unknown')}")
        print()

    safetensor_files = _find_safetensor_files(model_dir, index_data)

    print("=== Tensor Information ===")
    print(f"{'Tensor Name':<50} {'Shape':<25} {'Data Type':<15} {'File'}")
    print("-" * 110)

    total_tensors = 0
    for filename in safetensor_files:
        filepath = Path(model_dir) / filename
        if not filepath.exists():
            _warn_missing(filepath)
            continue

        print(f"\n--- {filename} ---")
        for tensor_name, shape, dtype in _iter_tensor_metadata(filepath):
            shape_str = str(tuple(shape))
            dtype_str = str(dtype)
            print(f"{tensor_name:<50} {shape_str:<25} {dtype_str:<15} {filename}")
            total_tensors += 1

    print(f"\nTotal tensors found: {total_tensors}")


def print_summary_only(model_dir="."):
    """Print only summary statistics: tensor count, parameter count, dtype histogram.

    Uses the same file discovery as ``inspect_safetensors_model`` (index file
    when present), so both modes report on the same set of tensors.
    """
    index_data = _load_index(model_dir)
    safetensor_files = _find_safetensor_files(model_dir, index_data)

    total_tensors = 0
    total_params = 0
    dtype_counts = Counter()

    for filename in safetensor_files:
        filepath = Path(model_dir) / filename
        if not filepath.exists():
            _warn_missing(filepath)
            continue

        for _tensor_name, shape, dtype in _iter_tensor_metadata(filepath):
            total_tensors += 1
            dtype_counts[str(dtype)] += 1
            total_params += _num_elements(shape)

    print("=== Model Summary ===")
    print(f"Total tensors: {total_tensors}")
    print(f"Total parameters: {total_params:,}")
    print("Data type distribution:")
    for dtype, count in sorted(dtype_counts.items()):
        print(f"  {dtype}: {count} tensors")


def main():
    """Parse command-line arguments and run the full listing or the summary."""
    parser = argparse.ArgumentParser(description="Inspect SafeTensors model files")
    parser.add_argument(
        "--model-dir",
        "-d",
        default=".",
        help="Directory containing the model files (default: current directory)",
    )
    parser.add_argument(
        "--summary",
        "-s",
        action="store_true",
        help="Show only summary statistics",
    )
    args = parser.parse_args()

    # Fail with a clean CLI error instead of a traceback from os.listdir.
    if not os.path.isdir(args.model_dir):
        parser.error(f"model directory not found: {args.model_dir}")

    if args.summary:
        print_summary_only(args.model_dir)
    else:
        inspect_safetensors_model(args.model_dir)


if __name__ == "__main__":
    main()