This commit is contained in:
Sigbjørn Skjæret 2025-12-17 02:24:05 +01:00 committed by GitHub
commit cfca1d4ebb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 4 additions and 3 deletions

View File

@ -189,10 +189,10 @@ class ModelBase:
return tensors return tensors
prefix = "model" if not self.is_mistral_format else "consolidated" prefix = "model" if not self.is_mistral_format else "consolidated"
part_names: set[str] = set(ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")) part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
is_safetensors: bool = len(part_names) > 0 is_safetensors: bool = len(part_names) > 0
if not is_safetensors: if not is_safetensors:
part_names = set(ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")) part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
tensor_names_from_index: set[str] = set() tensor_names_from_index: set[str] = set()
@ -209,7 +209,8 @@ class ModelBase:
if weight_map is None or not isinstance(weight_map, dict): if weight_map is None or not isinstance(weight_map, dict):
raise ValueError(f"Can't load 'weight_map' from {index_name!r}") raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
tensor_names_from_index.update(weight_map.keys()) tensor_names_from_index.update(weight_map.keys())
part_names |= set(weight_map.values()) part_dict: dict[str, None] = dict.fromkeys(list(weight_map.values()) + part_names, None)
part_names = list(part_dict.keys())
else: else:
weight_map = {} weight_map = {}
else: else: