maybe make like this
This commit is contained in:
parent
475527f0ee
commit
31fcbbb76a
|
|
@ -46,12 +46,13 @@ def refresh_controlnets(model_paths):
|
|||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def assert_model_integrity():
|
||||
if model_base.unet_with_lora is None or not hasattr(model_base.unet_with_lora, "model"):
|
||||
print('[Info] Skipping model integrity check: base model is not loaded.')
|
||||
return True
|
||||
error_message = None
|
||||
|
||||
if not isinstance(model_base.unet_with_lora.model, SDXL):
|
||||
raise NotImplementedError('You have selected base model other than SDXL. This is not supported yet.')
|
||||
error_message = 'You have selected base model other than SDXL. This is not supported yet.'
|
||||
|
||||
if error_message is not None:
|
||||
raise NotImplementedError(error_message)
|
||||
|
||||
return True
|
||||
|
||||
|
|
@ -61,11 +62,6 @@ def assert_model_integrity():
|
|||
def refresh_base_model(name, vae_name=None):
|
||||
global model_base
|
||||
|
||||
if name is None or name == 'None':
|
||||
print('[Info] No base model loaded.')
|
||||
model_base = core.StableDiffusionModel()
|
||||
return
|
||||
|
||||
filename = get_file_from_folder_list(name, modules.config.paths_checkpoints)
|
||||
|
||||
vae_filename = None
|
||||
|
|
@ -219,34 +215,17 @@ def set_clip_skip(clip_skip: int):
|
|||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def clear_all_caches():
|
||||
global final_clip
|
||||
if final_clip is not None and hasattr(final_clip, "fcs_cond_cache"):
|
||||
final_clip.fcs_cond_cache = {}
|
||||
else:
|
||||
print("[Info] Skipping cache clear: final_clip is None.")
|
||||
final_clip.fcs_cond_cache = {}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def prepare_text_encoder(async_call=True):
|
||||
if async_call:
|
||||
# TODO: make sure that this is always called in an async way so that users cannot feel it.
|
||||
pass
|
||||
|
||||
assert_model_integrity()
|
||||
|
||||
patchers = []
|
||||
|
||||
if final_clip is not None and hasattr(final_clip, "patcher"):
|
||||
patchers.append(final_clip.patcher)
|
||||
|
||||
if final_expansion is not None and hasattr(final_expansion, "patcher"):
|
||||
patchers.append(final_expansion.patcher)
|
||||
|
||||
if len(patchers) > 0:
|
||||
ldm_patched.modules.model_management.load_models_gpu(patchers)
|
||||
else:
|
||||
print("[Info] No models to load into GPU (no base model).")
|
||||
|
||||
ldm_patched.modules.model_management.load_models_gpu([final_clip.patcher, final_expansion.patcher])
|
||||
return
|
||||
|
||||
|
||||
|
|
@ -288,12 +267,15 @@ def refresh_everything(refiner_model_name, base_model_name, loras,
|
|||
return
|
||||
|
||||
|
||||
refresh_everything(
|
||||
refiner_model_name=modules.config.default_refiner_model_name,
|
||||
base_model_name=modules.config.default_base_model_name,
|
||||
loras=get_enabled_loras(modules.config.default_loras),
|
||||
vae_name=modules.config.default_vae,
|
||||
)
|
||||
if modules.config.default_base_model_name != 'None':
|
||||
refresh_everything(
|
||||
refiner_model_name=modules.config.default_refiner_model_name,
|
||||
base_model_name=modules.config.default_base_model_name,
|
||||
loras=get_enabled_loras(modules.config.default_loras),
|
||||
vae_name=modules.config.default_vae,
|
||||
)
|
||||
else:
|
||||
print('[Startup] Skipping model load (default_base_model_name is "None").')
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
|
|
|
|||
Loading…
Reference in New Issue