fix: do not reload model when VAE stays the same

Author: Manuel Schmid, 2024-05-04 21:12:25 +02:00
parent 15696da9b8
commit 4a3aac09a3
3 changed files with 11 additions and 9 deletions

Changed file 1 of 3:

@@ -427,12 +427,13 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
     return (ldm_patched.modules.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
 
-def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, vae_filename=None):
+def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, vae_filename_param=None):
     sd = ldm_patched.modules.utils.load_torch_file(ckpt_path)
     sd_keys = sd.keys()
     clip = None
     clipvision = None
     vae = None
+    vae_filename = None
     model = None
     model_patcher = None
     clip_target = None
@@ -462,11 +463,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         model.load_model_weights(sd, "model.diffusion_model.")
 
     if output_vae:
-        if vae_filename is None:
+        if vae_filename_param is None:
             vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
             vae_sd = model_config.process_vae_state_dict(vae_sd)
         else:
-            vae_sd = ldm_patched.modules.utils.load_torch_file(vae_filename)
+            vae_sd = ldm_patched.modules.utils.load_torch_file(vae_filename_param)
+            vae_filename = vae_filename_param
         vae = VAE(sd=vae_sd)
 
     if output_clip:
@@ -488,7 +490,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
             print("loaded straight to GPU")
             model_management.load_model_gpu(model_patcher)
 
-    return model_patcher, clip, vae, clipvision
+    return model_patcher, clip, vae, vae_filename, clipvision
 
 def load_unet_state_dict(sd): #load unet in diffusers format

Changed file 2 of 3:

@@ -41,7 +41,7 @@ class StableDiffusionModel:
         self.clip = clip
         self.clip_vision = clip_vision
         self.filename = filename
-        self.vae_filename = filename
+        self.vae_filename = vae_filename
         self.unet_with_lora = unet
         self.clip_with_lora = clip
         self.visited_loras = ''
@@ -144,9 +144,9 @@ def apply_controlnet(positive, negative, control_net, image, strength, start_per
 @torch.no_grad()
 @torch.inference_mode()
 def load_model(ckpt_filename, vae_filename=None):
-    unet, clip, vae, clip_vision = load_checkpoint_guess_config(ckpt_filename, embedding_directory=path_embeddings,
-                                                                vae_filename=vae_filename)
-    return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename)
+    unet, clip, vae, vae_filename, clip_vision = load_checkpoint_guess_config(ckpt_filename, embedding_directory=path_embeddings,
+                                                                              vae_filename_param=vae_filename)
+    return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename, vae_filename=vae_filename)
 
 
 @torch.no_grad()

Changed file 3 of 3:

@@ -71,7 +71,7 @@ def refresh_base_model(name, vae_name=None):
     if model_base.filename == filename and model_base.vae_filename == vae_filename:
         return
 
-    model_base = core.StableDiffusionModel()
+    model_base = core.StableDiffusionModel(vae_filename=vae_filename)
     model_base = core.load_model(filename, vae_filename)
     print(f'Base model loaded: {model_base.filename}')
     return
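
For context, a minimal sketch of the behaviour this commit enables. It is not code from this repository: the CachedPipeline class and the file names are purely illustrative stand-ins for refresh_base_model and core.load_model. Because the VAE filename is now threaded through load_checkpoint_guess_config and stored on StableDiffusionModel, the early-return guard can see that nothing changed and skip the reload.

# Illustrative sketch only -- mirrors the refresh_base_model guard above,
# not the project's actual classes or file layout.
class CachedPipeline:
    def __init__(self):
        self.filename = None
        self.vae_filename = None

    def load(self, filename, vae_filename):
        # Stand-in for core.load_model(); remembers which checkpoint/VAE is loaded.
        print(f'loading checkpoint={filename} vae={vae_filename}')
        self.filename = filename
        self.vae_filename = vae_filename

    def refresh(self, filename, vae_filename=None):
        # Before this fix the stored vae_filename held the checkpoint filename,
        # so this comparison never matched and the model was reloaded even when
        # the VAE stayed the same.
        if self.filename == filename and self.vae_filename == vae_filename:
            return  # nothing changed, keep the already-loaded model
        self.load(filename, vae_filename)


pipe = CachedPipeline()
pipe.refresh('model_a.safetensors', 'vae_a.safetensors')  # loads
pipe.refresh('model_a.safetensors', 'vae_a.safetensors')  # skipped, nothing changed
pipe.refresh('model_a.safetensors', 'vae_b.safetensors')  # VAE changed -> reload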