Use weights_only for loading (#3427)
Co-authored-by: Manuel Schmid <9307310+mashb1t@users.noreply.github.com>
parent 1a53e0676a
commit da3d4d006f
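This change threads weights_only=True through every torch.load call in the tree. With that flag, PyTorch deserializes checkpoints with a restricted unpickler that only reconstructs tensors and a small safelist of plain container types, so a malicious pickle payload raises an error instead of executing code at load time. A minimal sketch of the pattern (the path is a placeholder, not a file from this commit):

    import torch

    # Before: the default unpickler can run arbitrary code embedded in the file.
    # checkpoint = torch.load('model.pth', map_location='cpu')

    # After: only safelisted types (tensors, dicts, lists, ...) are rebuilt;
    # anything else fails the load instead of being executed.
    checkpoint = torch.load('model.pth', map_location='cpu', weights_only=True)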
@@ -216,9 +216,9 @@ def is_url(url_or_filename):
 
 def load_checkpoint(model,url_or_filename):
     if is_url(url_or_filename):
         cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
-        checkpoint = torch.load(cached_file, map_location='cpu')
+        checkpoint = torch.load(cached_file, map_location='cpu', weights_only=True)
     elif os.path.isfile(url_or_filename):
-        checkpoint = torch.load(url_or_filename, map_location='cpu')
+        checkpoint = torch.load(url_or_filename, map_location='cpu', weights_only=True)
     else:
         raise RuntimeError('checkpoint url or path is invalid')
@@ -78,9 +78,9 @@ def blip_nlvr(pretrained='',**kwargs):
 def load_checkpoint(model,url_or_filename):
     if is_url(url_or_filename):
         cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
-        checkpoint = torch.load(cached_file, map_location='cpu')
+        checkpoint = torch.load(cached_file, map_location='cpu', weights_only=True)
     elif os.path.isfile(url_or_filename):
-        checkpoint = torch.load(url_or_filename, map_location='cpu')
+        checkpoint = torch.load(url_or_filename, map_location='cpu', weights_only=True)
     else:
         raise RuntimeError('checkpoint url or path is invalid')
     state_dict = checkpoint['model']
@@ -19,7 +19,7 @@ def init_detection_model(model_name, half=False, device='cuda', model_rootpath=None):
         url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
 
     # TODO: clean pretrained model
-    load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
+    load_net = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=True)
     # remove unnecessary 'module.'
     for k, v in deepcopy(load_net).items():
         if k.startswith('module.'):
@@ -17,7 +17,7 @@ def init_parsing_model(model_name='bisenet', half=False, device='cuda', model_rootpath=None):
 
     model_path = load_file_from_url(
         url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
-    load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
+    load_net = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=True)
     model.load_state_dict(load_net, strict=True)
     model.eval()
     model = model.to(device)
@@ -104,7 +104,7 @@ def load_ip_adapter(clip_vision_path, ip_negative_path, ip_adapter_path):
     offload_device = torch.device('cpu')
 
     use_fp16 = model_management.should_use_fp16(device=load_device)
-    ip_state_dict = torch.load(ip_adapter_path, map_location="cpu")
+    ip_state_dict = torch.load(ip_adapter_path, map_location="cpu", weights_only=True)
     plus = "latents" in ip_state_dict["image_proj"]
     cross_attention_dim = ip_state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[1]
     sdxl = cross_attention_dim == 2048
@@ -8,7 +8,7 @@ class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
         if clip_stats_path is None:
             clip_mean, clip_std = torch.zeros(timestep_dim), torch.ones(timestep_dim)
         else:
-            clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu")
+            clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu", weights_only=True)
         self.register_buffer("data_mean", clip_mean[None, :], persistent=False)
         self.register_buffer("data_std", clip_std[None, :], persistent=False)
         self.time_embed = Timestep(timestep_dim)
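The value unpacked here is just a (mean, std) pair of tensors, which stays loadable under the restricted unpickler since tuples of tensors are safelisted. A quick round-trip sketch (file name and shape are illustrative only):

    import torch

    clip_mean, clip_std = torch.zeros(1280), torch.ones(1280)  # illustrative shape
    torch.save((clip_mean, clip_std), 'clip_stats.pt')         # tuple of tensors only
    mean, std = torch.load('clip_stats.pt', map_location='cpu', weights_only=True)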
@@ -326,7 +326,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None):
                 except:
                     embed_out = safe_load_embed_zip(embed_path)
             else:
-                embed = torch.load(embed_path, map_location="cpu")
+                embed = torch.load(embed_path, map_location="cpu", weights_only=True)
     except Exception as e:
         print(traceback.format_exc())
         print()
@@ -377,15 +377,15 @@ class VQAutoEncoder(nn.Module):
         )
 
         if model_path is not None:
-            chkpt = torch.load(model_path, map_location="cpu")
+            chkpt = torch.load(model_path, map_location="cpu", weights_only=True)
             if "params_ema" in chkpt:
                 self.load_state_dict(
-                    torch.load(model_path, map_location="cpu")["params_ema"]
+                    torch.load(model_path, map_location="cpu", weights_only=True)["params_ema"]
                 )
                 logger.info(f"vqgan is loaded from: {model_path} [params_ema]")
             elif "params" in chkpt:
                 self.load_state_dict(
-                    torch.load(model_path, map_location="cpu")["params"]
+                    torch.load(model_path, map_location="cpu", weights_only=True)["params"]
                 )
                 logger.info(f"vqgan is loaded from: {model_path} [params]")
             else:
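Unchanged by this commit: the block above reads the same checkpoint file twice, once into chkpt to inspect its keys and again inside each load_state_dict call. A single-read equivalent would look roughly like this sketch (else-branch error handling omitted):

    chkpt = torch.load(model_path, map_location="cpu", weights_only=True)
    key = "params_ema" if "params_ema" in chkpt else "params"
    self.load_state_dict(chkpt[key])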
@@ -273,8 +273,8 @@ class GFPGANBilinear(nn.Module):
         if decoder_load_path:
             self.stylegan_decoder.load_state_dict(
                 torch.load(
-                    decoder_load_path, map_location=lambda storage, loc: storage
-                )["params_ema"]
+                    decoder_load_path, map_location=lambda storage, loc: storage,
+                    weights_only=True)["params_ema"]
             )
         # fix decoder without updating params
         if fix_decoder:
@@ -373,8 +373,8 @@ class GFPGANv1(nn.Module):
         if decoder_load_path:
             self.stylegan_decoder.load_state_dict(
                 torch.load(
-                    decoder_load_path, map_location=lambda storage, loc: storage
-                )["params_ema"]
+                    decoder_load_path, map_location=lambda storage, loc: storage,
+                    weights_only=True)["params_ema"]
             )
         # fix decoder without updating params
         if fix_decoder:
@@ -284,8 +284,8 @@ class GFPGANv1Clean(nn.Module):
         if decoder_load_path:
             self.stylegan_decoder.load_state_dict(
                 torch.load(
-                    decoder_load_path, map_location=lambda storage, loc: storage
-                )["params_ema"]
+                    decoder_load_path, map_location=lambda storage, loc: storage,
+                    weights_only=True)["params_ema"]
             )
         # fix decoder without updating params
         if fix_decoder:
@@ -231,7 +231,7 @@ def get_previewer(model):
     if vae_approx_filename in VAE_approx_models:
         VAE_approx_model = VAE_approx_models[vae_approx_filename]
     else:
-        sd = torch.load(vae_approx_filename, map_location='cpu')
+        sd = torch.load(vae_approx_filename, map_location='cpu', weights_only=True)
         VAE_approx_model = VAEApprox()
         VAE_approx_model.load_state_dict(sd)
         del sd
@@ -196,7 +196,7 @@ class InpaintWorker:
 
         if inpaint_head_model is None:
             inpaint_head_model = InpaintHead()
-            sd = torch.load(inpaint_head_model_path, map_location='cpu')
+            sd = torch.load(inpaint_head_model_path, map_location='cpu', weights_only=True)
             inpaint_head_model.load_state_dict(sd)
 
         feed = torch.cat([
@@ -17,7 +17,7 @@ def perform_upscale(img):
 
     if model is None:
         model_filename = downloading_upscale_model()
-        sd = torch.load(model_filename)
+        sd = torch.load(model_filename, weights_only=True)
         sdo = OrderedDict()
         for k, v in sd.items():
             sdo[k.replace('residual_block_', 'RDB')] = v
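One caveat worth knowing: weights_only=True rejects any checkpoint whose pickle stream references types outside the safelist, which can break loading of older or exotic files. For a trusted file that trips this, recent PyTorch versions let you safelist the specific class rather than falling back to the permissive loader; a hedged sketch (the class and path are placeholders, not from this commit):

    import torch

    # Hypothetical recovery for a *trusted* legacy checkpoint:
    # from mypackage.config import TrainConfig          # placeholder class pickled in the file
    # torch.serialization.add_safe_globals([TrainConfig])
    sd = torch.load('legacy.ckpt', map_location='cpu', weights_only=True)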