diff --git a/modules/default_pipeline.py b/modules/default_pipeline.py index 524f426a..0ccb0796 100644 --- a/modules/default_pipeline.py +++ b/modules/default_pipeline.py @@ -46,12 +46,13 @@ def refresh_controlnets(model_paths): @torch.no_grad() @torch.inference_mode() def assert_model_integrity(): - if model_base.unet_with_lora is None or not hasattr(model_base.unet_with_lora, "model"): - print('[Info] Skipping model integrity check: base model is not loaded.') - return True + error_message = None if not isinstance(model_base.unet_with_lora.model, SDXL): - raise NotImplementedError('You have selected base model other than SDXL. This is not supported yet.') + error_message = 'You have selected base model other than SDXL. This is not supported yet.' + + if error_message is not None: + raise NotImplementedError(error_message) return True @@ -61,11 +62,6 @@ def assert_model_integrity(): def refresh_base_model(name, vae_name=None): global model_base - if name is None or name == 'None': - print('[Info] No base model loaded.') - model_base = core.StableDiffusionModel() - return - filename = get_file_from_folder_list(name, modules.config.paths_checkpoints) vae_filename = None @@ -219,34 +215,17 @@ def set_clip_skip(clip_skip: int): @torch.no_grad() @torch.inference_mode() def clear_all_caches(): - global final_clip - if final_clip is not None and hasattr(final_clip, "fcs_cond_cache"): - final_clip.fcs_cond_cache = {} - else: - print("[Info] Skipping cache clear: final_clip is None.") + final_clip.fcs_cond_cache = {} @torch.no_grad() @torch.inference_mode() def prepare_text_encoder(async_call=True): if async_call: + # TODO: make sure that this is always called in an async way so that users cannot feel it. pass - assert_model_integrity() - - patchers = [] - - if final_clip is not None and hasattr(final_clip, "patcher"): - patchers.append(final_clip.patcher) - - if final_expansion is not None and hasattr(final_expansion, "patcher"): - patchers.append(final_expansion.patcher) - - if len(patchers) > 0: - ldm_patched.modules.model_management.load_models_gpu(patchers) - else: - print("[Info] No models to load into GPU (no base model).") - + ldm_patched.modules.model_management.load_models_gpu([final_clip.patcher, final_expansion.patcher]) return @@ -288,12 +267,15 @@ def refresh_everything(refiner_model_name, base_model_name, loras, return -refresh_everything( - refiner_model_name=modules.config.default_refiner_model_name, - base_model_name=modules.config.default_base_model_name, - loras=get_enabled_loras(modules.config.default_loras), - vae_name=modules.config.default_vae, -) +if modules.config.default_base_model_name != 'None': + refresh_everything( + refiner_model_name=modules.config.default_refiner_model_name, + base_model_name=modules.config.default_base_model_name, + loras=get_enabled_loras(modules.config.default_loras), + vae_name=modules.config.default_vae, + ) +else: + print('[Startup] Skipping model load (default_base_model_name is "None").') @torch.no_grad()