diff --git a/ldm_patched/modules/sd.py b/ldm_patched/modules/sd.py
index 61260fd0..282f2559 100644
--- a/ldm_patched/modules/sd.py
+++ b/ldm_patched/modules/sd.py
@@ -427,12 +427,13 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
     return (ldm_patched.modules.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
 
 
-def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, vae_filename=None):
+def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, vae_filename_param=None):
     sd = ldm_patched.modules.utils.load_torch_file(ckpt_path)
     sd_keys = sd.keys()
     clip = None
     clipvision = None
     vae = None
+    vae_filename = None
     model = None
     model_patcher = None
     clip_target = None
@@ -462,11 +463,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         model.load_model_weights(sd, "model.diffusion_model.")
 
     if output_vae:
-        if vae_filename is None:
+        if vae_filename_param is None:
             vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
             vae_sd = model_config.process_vae_state_dict(vae_sd)
         else:
-            vae_sd = ldm_patched.modules.utils.load_torch_file(vae_filename)
+            vae_sd = ldm_patched.modules.utils.load_torch_file(vae_filename_param)
+            vae_filename = vae_filename_param
         vae = VAE(sd=vae_sd)
 
     if output_clip:
@@ -488,7 +490,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
             print("loaded straight to GPU")
             model_management.load_model_gpu(model_patcher)
 
-    return model_patcher, clip, vae, clipvision
+    return model_patcher, clip, vae, vae_filename, clipvision
 
 
 def load_unet_state_dict(sd): #load unet in diffusers format
diff --git a/modules/core.py b/modules/core.py
index c1fa8976..3ca4cc5b 100644
--- a/modules/core.py
+++ b/modules/core.py
@@ -41,7 +41,7 @@ class StableDiffusionModel:
         self.clip = clip
         self.clip_vision = clip_vision
         self.filename = filename
-        self.vae_filename = filename
+        self.vae_filename = vae_filename
         self.unet_with_lora = unet
         self.clip_with_lora = clip
         self.visited_loras = ''
@@ -144,9 +144,9 @@ def apply_controlnet(positive, negative, control_net, image, strength, start_per
 @torch.no_grad()
 @torch.inference_mode()
 def load_model(ckpt_filename, vae_filename=None):
-    unet, clip, vae, clip_vision = load_checkpoint_guess_config(ckpt_filename, embedding_directory=path_embeddings,
-                                                                vae_filename=vae_filename)
-    return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename)
+    unet, clip, vae, vae_filename, clip_vision = load_checkpoint_guess_config(ckpt_filename, embedding_directory=path_embeddings,
+                                                                              vae_filename_param=vae_filename)
+    return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename, vae_filename=vae_filename)
 
 
 @torch.no_grad()
diff --git a/modules/default_pipeline.py b/modules/default_pipeline.py
index 9f94170c..fcbeab7f 100644
--- a/modules/default_pipeline.py
+++ b/modules/default_pipeline.py
@@ -71,7 +71,7 @@ def refresh_base_model(name, vae_name=None):
     if model_base.filename == filename and model_base.vae_filename == vae_filename:
         return
 
-    model_base = core.StableDiffusionModel()
+    model_base = core.StableDiffusionModel(vae_filename=vae_filename)
     model_base = core.load_model(filename, vae_filename)
     print(f'Base model loaded: {model_base.filename}')
     return
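
For downstream callers, the visible API change is that load_checkpoint_guess_config now takes the external VAE path through the renamed vae_filename_param keyword and returns a five-element tuple, with the resolved vae_filename inserted before clipvision: it stays None when the VAE was taken from the checkpoint itself and echoes vae_filename_param when an external file was loaded. A minimal caller sketch under those assumptions; the checkpoint and VAE paths are hypothetical placeholders:

# Minimal sketch of the updated call contract; paths are hypothetical.
from ldm_patched.modules.sd import load_checkpoint_guess_config

ckpt_path = "models/checkpoints/base_model.safetensors"   # hypothetical path
external_vae = "models/vae/custom_vae.safetensors"        # hypothetical path

# The return value grew from 4 to 5 elements: the resolved vae_filename
# now travels back with the loaded components, so callers can record
# which VAE file is actually in use.
unet, clip, vae, vae_filename, clip_vision = load_checkpoint_guess_config(
    ckpt_path,
    vae_filename_param=external_vae,
)
assert vae_filename == external_vae  # echoes the external VAE path

Threading the filename back out is what lets StableDiffusionModel record its active VAE file and lets refresh_base_model's early-return check (filename and vae_filename both unchanged) skip a redundant reload.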