diff --git a/examples/model-conversion/scripts/causal/compare-logits.py b/examples/model-conversion/scripts/causal/compare-logits.py index 2744789099..894302c69e 100755 --- a/examples/model-conversion/scripts/causal/compare-logits.py +++ b/examples/model-conversion/scripts/causal/compare-logits.py @@ -1,10 +1,13 @@ #!/usr/bin/env python3 -import numpy as np import sys -import os +import numpy as np from pathlib import Path +# Add utils directory to path for direct script execution +sys.path.insert(0, str(Path(__file__).parent.parent / "utils")) +from common import get_model_name_from_env_path # type: ignore[import-not-found] + def quick_logits_check(pytorch_file, llamacpp_file): """Lightweight sanity check before NMSE""" @@ -35,20 +38,13 @@ def quick_logits_check(pytorch_file, llamacpp_file): return True def main(): - model_path = os.getenv('MODEL_PATH') - if not model_path: - print("Error: MODEL_PATH environment variable not set") - sys.exit(1) - - if not os.path.exists(model_path): - print(f"Error: Model file not found: {model_path}") - sys.exit(1) - - model_name = os.path.basename(model_path) + model_name = get_model_name_from_env_path('MODEL_PATH') data_dir = Path("data") - pytorch_file = data_dir / f"pytorch-{model_name}.bin" - llamacpp_file = data_dir / f"llamacpp-{model_name}.bin" + + llamacpp_model_name = get_model_name_from_env_path('CONVERTED_MODEL') + print(f"Using converted model: {llamacpp_model_name}") + llamacpp_file = data_dir / f"llamacpp-{llamacpp_model_name}.bin" if not pytorch_file.exists(): print(f"Error: PyTorch logits file not found: {pytorch_file}") diff --git a/examples/model-conversion/scripts/utils/__init__.py b/examples/model-conversion/scripts/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/model-conversion/scripts/utils/check-nmse.py b/examples/model-conversion/scripts/utils/check-nmse.py index 939e3153cc..83f63f9ff3 100755 --- a/examples/model-conversion/scripts/utils/check-nmse.py +++ b/examples/model-conversion/scripts/utils/check-nmse.py @@ -5,6 +5,7 @@ import sys import os import argparse from pathlib import Path +from common import get_model_name_from_env_path # type: ignore[import-not-found] def calculate_nmse(reference, test): mse = np.mean((test - reference) ** 2) @@ -67,11 +68,13 @@ def main(): parser.add_argument('-m', '--model-path', required=True, help='Path to the model directory') args = parser.parse_args() - model_name = os.path.basename(args.model_path) + model_name = get_model_name_from_env_path('MODEL_PATH') data_dir = Path("data") pytorch_file = data_dir / f"pytorch-{model_name}.bin" - llamacpp_file = data_dir / f"llamacpp-{model_name}.bin" + + llamacpp_model_name = get_model_name_from_env_path('CONVERTED_MODEL') + llamacpp_file = data_dir / f"llamacpp-{llamacpp_model_name}.bin" print(f"Model name: {model_name}") print(f"PyTorch logits file: {pytorch_file}") diff --git a/examples/model-conversion/scripts/utils/common.py b/examples/model-conversion/scripts/utils/common.py new file mode 100644 index 0000000000..945f9a1a1d --- /dev/null +++ b/examples/model-conversion/scripts/utils/common.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python3 + +import os +import sys + +def get_model_name_from_env_path(env_path_name): + model_path = os.getenv(env_path_name) + if not model_path: + print(f"Error: {env_path_name} environment variable not set") + sys.exit(1) + + if not os.path.exists(model_path): + print(f"Error: Model file not found: {model_path}") + sys.exit(1) + + name = os.path.basename(os.path.normpath(model_path)) + if name.endswith(".gguf"): + name = name[:-5] + + return name diff --git a/pyrightconfig.json b/pyrightconfig.json index 5320fe5864..a7bc007bdc 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -1,5 +1,5 @@ { - "extraPaths": ["gguf-py"], + "extraPaths": ["gguf-py", "examples/model-conversion/scripts"], "pythonVersion": "3.9", "pythonPlatform": "All", "reportUnusedImport": "warning",