From 6ab881b7c3abd7c1aac0d7f3a00bde68d9cfd484 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Wed, 4 Feb 2026 10:40:53 +0100 Subject: [PATCH] model-conversion : add tensor-info.py utility (#18954) This commit adds a new python script that can be used to print tensors information from a tensor in a safetensors model. The motivation for this is that during model conversion work it can sometimes be useful to verify the shape of tensors in the original model. While it is possible to print the tensors when loading the model this can be slow when working with larger models. With this script it is possible to quickly query tensor shapes. Example usage: ```console (venv) $ ./scripts/utils/tensor-info.py --help usage: tensor-info.py [-h] [-m MODEL_PATH] [-l] [tensor_name] Print tensor information from a safetensors model positional arguments: tensor_name Name of the tensor to inspect options: -h, --help show this help message and exit -m MODEL_PATH, --model-path MODEL_PATH Path to the model directory (default: MODEL_PATH environment variable) -l, --list List unique tensor patterns in the model (layer numbers replaced with #) ``` Listing tensor names: ```console (venv) $ ./scripts/utils/tensor-info.py -m ~/work/ai/models/google/embeddinggemma-300m -l embed_tokens.weight layers.#.input_layernorm.weight layers.#.mlp.down_proj.weight layers.#.mlp.gate_proj.weight layers.#.mlp.up_proj.weight layers.#.post_attention_layernorm.weight layers.#.post_feedforward_layernorm.weight layers.#.pre_feedforward_layernorm.weight layers.#.self_attn.k_norm.weight layers.#.self_attn.k_proj.weight layers.#.self_attn.o_proj.weight layers.#.self_attn.q_norm.weight layers.#.self_attn.q_proj.weight layers.#.self_attn.v_proj.weight norm.weight ``` Printing a specific tensor's information: ```console (venv) $ ./scripts/utils/tensor-info.py -m ~/work/ai/models/google/embeddinggemma-300m layers.0.input_layernorm.weight Tensor: layers.0.input_layernorm.weight File: model.safetensors Shape: [768] ``` --- .../scripts/utils/tensor-info.py | 159 ++++++++++++++++++ 1 file changed, 159 insertions(+) create mode 100755 examples/model-conversion/scripts/utils/tensor-info.py diff --git a/examples/model-conversion/scripts/utils/tensor-info.py b/examples/model-conversion/scripts/utils/tensor-info.py new file mode 100755 index 0000000000..12a3430b49 --- /dev/null +++ b/examples/model-conversion/scripts/utils/tensor-info.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 + +import argparse +import json +import os +import re +import sys +from pathlib import Path +from typing import Optional +from safetensors import safe_open + + +MODEL_SAFETENSORS_FILE = "model.safetensors" +MODEL_SAFETENSORS_INDEX = "model.safetensors.index.json" + + +def get_weight_map(model_path: Path) -> Optional[dict[str, str]]: + index_file = model_path / MODEL_SAFETENSORS_INDEX + + if index_file.exists(): + with open(index_file, 'r') as f: + index = json.load(f) + return index.get("weight_map", {}) + + return None + + +def get_all_tensor_names(model_path: Path) -> list[str]: + weight_map = get_weight_map(model_path) + + if weight_map is not None: + return list(weight_map.keys()) + + single_file = model_path / MODEL_SAFETENSORS_FILE + if single_file.exists(): + try: + with safe_open(single_file, framework="pt", device="cpu") as f: + return list(f.keys()) + except Exception as e: + print(f"Error reading {single_file}: {e}") + sys.exit(1) + + print(f"Error: No safetensors files found in {model_path}") + sys.exit(1) + + +def find_tensor_file(model_path: Path, tensor_name: str) -> Optional[str]: + weight_map = get_weight_map(model_path) + + if weight_map is not None: + return weight_map.get(tensor_name) + + single_file = model_path / MODEL_SAFETENSORS_FILE + if single_file.exists(): + return single_file.name + + return None + + +def normalize_tensor_name(tensor_name: str) -> str: + normalized = re.sub(r'\.\d+\.', '.#.', tensor_name) + normalized = re.sub(r'\.\d+$', '.#', normalized) + return normalized + + +def list_all_tensors(model_path: Path, unique: bool = False): + tensor_names = get_all_tensor_names(model_path) + + if unique: + seen = set() + for tensor_name in sorted(tensor_names): + normalized = normalize_tensor_name(tensor_name) + if normalized not in seen: + seen.add(normalized) + print(normalized) + else: + for tensor_name in sorted(tensor_names): + print(tensor_name) + + +def print_tensor_info(model_path: Path, tensor_name: str): + tensor_file = find_tensor_file(model_path, tensor_name) + + if tensor_file is None: + print(f"Error: Could not find tensor '{tensor_name}' in model index") + print(f"Model path: {model_path}") + sys.exit(1) + + file_path = model_path / tensor_file + + try: + with safe_open(file_path, framework="pt", device="cpu") as f: + if tensor_name in f.keys(): + tensor_slice = f.get_slice(tensor_name) + shape = tensor_slice.get_shape() + print(f"Tensor: {tensor_name}") + print(f"File: {tensor_file}") + print(f"Shape: {shape}") + else: + print(f"Error: Tensor '{tensor_name}' not found in {tensor_file}") + sys.exit(1) + + except FileNotFoundError: + print(f"Error: The file '{file_path}' was not found.") + sys.exit(1) + except Exception as e: + print(f"An error occurred: {e}") + sys.exit(1) + + +def main(): + parser = argparse.ArgumentParser( + description="Print tensor information from a safetensors model" + ) + parser.add_argument( + "tensor_name", + nargs="?", # optional (if --list is used for example) + help="Name of the tensor to inspect" + ) + parser.add_argument( + "-m", "--model-path", + type=Path, + help="Path to the model directory (default: MODEL_PATH environment variable)" + ) + parser.add_argument( + "-l", "--list", + action="store_true", + help="List unique tensor patterns in the model (layer numbers replaced with #)" + ) + + args = parser.parse_args() + + model_path = args.model_path + if model_path is None: + model_path_str = os.environ.get("MODEL_PATH") + if model_path_str is None: + print("Error: --model-path not provided and MODEL_PATH environment variable not set") + sys.exit(1) + model_path = Path(model_path_str) + + if not model_path.exists(): + print(f"Error: Model path does not exist: {model_path}") + sys.exit(1) + + if not model_path.is_dir(): + print(f"Error: Model path is not a directory: {model_path}") + sys.exit(1) + + if args.list: + list_all_tensors(model_path, unique=True) + else: + if args.tensor_name is None: + print("Error: tensor_name is required when not using --list") + sys.exit(1) + print_tensor_info(model_path, args.tensor_name) + + +if __name__ == "__main__": + main()